diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py
index 0fc3b074533d..bc0e0cbd85e1 100644
--- a/tests/v1/test_serial_utils.py
+++ b/tests/v1/test_serial_utils.py
@@ -22,6 +22,10 @@ class MyType:
     list_of_tensors: list[torch.Tensor]
     numpy_array: np.ndarray
    unrecognized: UnrecognizedType
+    small_f_contig_tensor: torch.Tensor
+    large_f_contig_tensor: torch.Tensor
+    small_non_contig_tensor: torch.Tensor
+    large_non_contig_tensor: torch.Tensor
 
 
 def test_encode_decode():
@@ -40,6 +44,10 @@ def test_encode_decode():
         ],
         numpy_array=np.arange(512),
         unrecognized=UnrecognizedType(33),
+        small_f_contig_tensor=torch.rand(5, 4).t(),
+        large_f_contig_tensor=torch.rand(1024, 4).t(),
+        small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
+        large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
     )
 
     encoder = MsgpackEncoder()
@@ -47,10 +55,10 @@ def test_encode_decode():
 
     encoded = encoder.encode(obj)
 
-    # There should be the main buffer + 2 large tensor buffers
-    # + 1 large numpy array. "large" is <= 256 bytes.
+    # There should be the main buffer + 4 large tensor buffers
+    # + 1 large numpy array. "large" is >= 512 bytes.
     # The two small tensors are encoded inline.
-    assert len(encoded) == 4
+    assert len(encoded) == 6
 
     decoded: MyType = decoder.decode(encoded)
 
@@ -62,7 +70,7 @@ def test_encode_decode():
 
     encoded2 = encoder.encode_into(obj, preallocated)
 
-    assert len(encoded2) == 4
+    assert len(encoded2) == 6
     assert encoded2[0] is preallocated
 
     decoded2: MyType = decoder.decode(encoded2)
@@ -78,3 +86,9 @@ def assert_equal(obj1: MyType, obj2: MyType):
                for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
     assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
     assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
+    assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
+    assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
+    assert torch.equal(obj1.small_non_contig_tensor,
+                       obj2.small_non_contig_tensor)
+    assert torch.equal(obj1.large_non_contig_tensor,
+                       obj2.large_non_contig_tensor)
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 99b352fdef80..3af6793fde74 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -14,9 +14,10 @@
 
 CUSTOM_TYPE_PICKLE = 1
 CUSTOM_TYPE_CLOUDPICKLE = 2
+CUSTOM_TYPE_RAW_VIEW = 3
 
 # TODO calibrate this size
-INLINE_BUF_SIZE_THRESHOLD = 256
+MIN_NOCOPY_BUF_SIZE = 512
 
 bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
 
@@ -76,14 +77,16 @@ def _encode_ndarray(
         self, obj: np.ndarray
     ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
         assert self.aux_buffers is not None
-        if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
-            # Encode small arrays and scalars inline.
-            data = obj.data
+        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
+        if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
+            # Encode small arrays and scalars inline. Using this extension type
+            # ensures we can avoid copying when decoding.
+            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
         else:
-            # Otherwise encode index of backing buffer.
-            obj = np.ascontiguousarray(obj)
+            # Otherwise encode index of backing buffer to avoid copy.
             data = len(self.aux_buffers)
-            self.aux_buffers.append(obj.data)
+            self.aux_buffers.append(arr_data)
+
         # We serialize the ndarray as a tuple of native types.
         # The data is either inlined if small, or an index into a list of
         # backing buffers that we've stashed in `aux_buffers`.
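
The encoder hunk above hinges on how numpy exposes array memory. Below is a minimal standalone sketch (not part of the patch, numpy only) of the layout check that the new `_encode_ndarray` performs: contiguous arrays hand over their buffer as-is, while non-contiguous views pay exactly one flattening copy via `tobytes()`.

```python
import numpy as np

# A sliced view is typically not C-contiguous: its memoryview still exists,
# but msgpack cannot ship a strided buffer without flattening it first.
arr = np.arange(12).reshape(3, 4)[:, 1:3]
assert not arr.data.c_contiguous

# Contiguous arrays take the zero-copy path: .data is just a memoryview
# over the existing allocation, no bytes are duplicated.
contig = np.arange(12).reshape(3, 4)
assert contig.data.c_contiguous

# Non-contiguous arrays fall back to tobytes(), which makes a single
# C-ordered copy, comparable to what np.ascontiguousarray() cost before.
packed = arr.tobytes()
restored = np.frombuffer(packed, dtype=arr.dtype).reshape(arr.shape)
assert np.array_equal(restored, arr)
```
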
@@ -131,6 +134,8 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
         return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
 
     def ext_hook(self, code: int, data: memoryview) -> Any:
+        if code == CUSTOM_TYPE_RAW_VIEW:
+            return data
         if code == CUSTOM_TYPE_PICKLE:
             return pickle.loads(data)
         if code == CUSTOM_TYPE_CLOUDPICKLE:
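
For context, a hedged round-trip sketch of how the two paths behave end to end. It assumes `MsgpackEncoder`/`MsgpackDecoder` from `vllm.v1.serial_utils` can be used on a bare tensor the same way the test above uses them on `MyType`, and that frame counts follow the 512-byte `MIN_NOCOPY_BUF_SIZE` threshold:

```python
import torch
from vllm.v1.serial_utils import MsgpackEncoder, MsgpackDecoder

encoder = MsgpackEncoder()
decoder = MsgpackDecoder(torch.Tensor)

# Under 512 bytes: inlined in the main buffer as a CUSTOM_TYPE_RAW_VIEW ext,
# so ext_hook can return a view of the bytes instead of copying them.
small = torch.rand(2, 4)[:, 1:3]
frames = encoder.encode(small)
assert len(frames) == 1
assert torch.equal(decoder.decode(frames), small)

# 512 bytes or more: shipped as a separate zero-copy aux buffer,
# so the main msgpack payload carries only its index.
large = torch.rand(1024, 512)[:, 10:20]
frames = encoder.encode(large)
assert len(frames) == 2
assert torch.equal(decoder.decode(frames), large)
```
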