tests/v1/test_serial_utils.py (18 additions, 4 deletions)
@@ -22,6 +22,10 @@ class MyType:
     list_of_tensors: list[torch.Tensor]
     numpy_array: np.ndarray
     unrecognized: UnrecognizedType
+    small_f_contig_tensor: torch.Tensor
+    large_f_contig_tensor: torch.Tensor
+    small_non_contig_tensor: torch.Tensor
+    large_non_contig_tensor: torch.Tensor


 def test_encode_decode():
@@ -40,17 +44,21 @@ def test_encode_decode():
         ],
         numpy_array=np.arange(512),
         unrecognized=UnrecognizedType(33),
+        small_f_contig_tensor=torch.rand(5, 4).t(),
+        large_f_contig_tensor=torch.rand(1024, 4).t(),
+        small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
+        large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
     )

     encoder = MsgpackEncoder()
     decoder = MsgpackDecoder(MyType)

     encoded = encoder.encode(obj)

-    # There should be the main buffer + 2 large tensor buffers
-    # + 1 large numpy array. "large" is <= 256 bytes.
-    assert len(encoded) == 4
+    # There should be the main buffer + 4 large tensor buffers
+    # + 1 large numpy array. "large" means >= 512 bytes.
+    # The two small tensors are encoded inline.
+    assert len(encoded) == 6

     decoded: MyType = decoder.decode(encoded)

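A note on the four new test tensors: `torch.rand(5, 4).t()` is a transposed view whose backing storage is column-major (Fortran-contiguous), while a column slice leaves gaps between rows, so none of these views are C-contiguous. A minimal sketch of the layouts the test exercises (plain PyTorch, not part of the diff):

```python
import torch

# A transpose swaps strides: C-contiguity is lost, but the underlying
# storage is still contiguous in Fortran (column-major) order.
f_contig = torch.rand(5, 4).t()
assert not f_contig.is_contiguous()
assert f_contig.t().is_contiguous()

# A column slice is contiguous in neither order: rows have gaps.
non_contig = torch.rand(1024, 512)[:, 10:20]
assert not non_contig.is_contiguous()

# A raw-bytes serializer must flatten such views explicitly.
flat = non_contig.contiguous()
assert torch.equal(non_contig, flat)
```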
@@ -62,7 +70,7 @@

     encoded2 = encoder.encode_into(obj, preallocated)

-    assert len(encoded2) == 4
+    assert len(encoded2) == 6
     assert encoded2[0] is preallocated

     decoded2: MyType = decoder.decode(encoded2)
@@ -78,3 +86,9 @@ def assert_equal(obj1: MyType, obj2: MyType):
                for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
     assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
     assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
+    assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
+    assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
+    assert torch.equal(obj1.small_non_contig_tensor,
+                       obj2.small_non_contig_tensor)
+    assert torch.equal(obj1.large_non_contig_tensor,
+                       obj2.large_non_contig_tensor)
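Why the expected buffer count moves from 4 to 6: the two new large tensors each clear the 512-byte no-copy threshold and get their own aux buffers, while the two small ones stay inline. Rough size arithmetic for the new fields (float32, 4 bytes per element; my annotations, not from the PR):

```python
import torch

small_f  = torch.rand(5, 4).t()             # 20 elems    ->    80 B: inline
large_f  = torch.rand(1024, 4).t()          # 4096 elems  ->  16 KiB: aux buffer
small_nc = torch.rand(2, 4)[:, 1:3]         # 4 elems     ->    16 B: inline
large_nc = torch.rand(1024, 512)[:, 10:20]  # 10240 elems ->  40 KiB: aux buffer

for t in (small_f, large_f, small_nc, large_nc):
    print(t.numel() * t.element_size(), "bytes")
```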
vllm/v1/serial_utils.py (12 additions, 7 deletions)
@@ -14,9 +14,10 @@

 CUSTOM_TYPE_PICKLE = 1
 CUSTOM_TYPE_CLOUDPICKLE = 2
+CUSTOM_TYPE_RAW_VIEW = 3

 # TODO calibrate this size
-INLINE_BUF_SIZE_THRESHOLD = 256
+MIN_NOCOPY_BUF_SIZE = 512

 bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]

@@ -76,14 +76,16 @@ def _encode_ndarray(
         self, obj: np.ndarray
     ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
         assert self.aux_buffers is not None
-        if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
-            # Encode small arrays and scalars inline.
-            data = obj.data
+        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
+        if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
+            # Encode small arrays and scalars inline. Using this extension type
+            # ensures we can avoid copying when decoding.
+            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
         else:
-            # Otherwise encode index of backing buffer.
-            obj = np.ascontiguousarray(obj)
+            # Otherwise encode index of backing buffer to avoid copy.
             data = len(self.aux_buffers)
-            self.aux_buffers.append(obj.data)
+            self.aux_buffers.append(arr_data)

         # We serialize the ndarray as a tuple of native types.
         # The data is either inlined if small, or an index into a list of
         # backing buffers that we've stashed in `aux_buffers`.
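The heart of the change: `obj.data` can be shipped as a raw buffer only when the array is C-contiguous; for strided views, `tobytes()` makes a contiguous copy up front (where the old code called `np.ascontiguousarray`). A minimal NumPy sketch of that distinction (illustrative, not from the PR):

```python
import numpy as np

a = np.arange(12, dtype=np.int32).reshape(3, 4)
assert a.data.c_contiguous        # raw buffer is usable as-is: zero copy

v = a[:, 1:3]                     # strided view: rows are not adjacent
assert not v.data.c_contiguous
raw = v.tobytes()                 # forced contiguous copy, C order

# Round trip from raw bytes + dtype + shape, as the decoder does.
r = np.frombuffer(raw, dtype=v.dtype).reshape(v.shape)
assert np.array_equal(v, r)
```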
@@ -131,6 +134,8 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
         return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

     def ext_hook(self, code: int, data: memoryview) -> Any:
+        if code == CUSTOM_TYPE_RAW_VIEW:
+            return data
         if code == CUSTOM_TYPE_PICKLE:
             return pickle.loads(data)
         if code == CUSTOM_TYPE_CLOUDPICKLE:
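On the decode side, returning the memoryview directly from `ext_hook` hands the caller a view into the received message, so inlined array bytes are never copied. A sketch of that round trip, assuming `msgpack` here is msgspec's msgpack module (as the `msgpack.Ext`/`ext_hook` usage in this file suggests):

```python
import numpy as np
from msgspec import msgpack

CUSTOM_TYPE_RAW_VIEW = 3

arr = np.arange(8, dtype=np.int16)
packed = msgpack.encode(msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data))

def ext_hook(code: int, data: memoryview):
    # Hand back the view into the message buffer: no copy.
    return data if code == CUSTOM_TYPE_RAW_VIEW else None

view = msgpack.decode(packed, ext_hook=ext_hook)
restored = np.frombuffer(view, dtype=np.int16)
assert np.array_equal(arr, restored)
```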