Skip to content

Commit 3b76e02

Browse files
njhill and k-chen
authored and committed
[BugFix] Handle non-contiguous tensors properly when serializing (vllm-project#16492)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 18a4aa1 commit 3b76e02

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

tests/v1/test_serial_utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ class MyType:
2222
list_of_tensors: list[torch.Tensor]
2323
numpy_array: np.ndarray
2424
unrecognized: UnrecognizedType
25+
small_f_contig_tensor: torch.Tensor
26+
large_f_contig_tensor: torch.Tensor
27+
small_non_contig_tensor: torch.Tensor
28+
large_non_contig_tensor: torch.Tensor
2529

2630

2731
def test_encode_decode():
@@ -40,17 +44,21 @@ def test_encode_decode():
4044
],
4145
numpy_array=np.arange(512),
4246
unrecognized=UnrecognizedType(33),
47+
small_f_contig_tensor=torch.rand(5, 4).t(),
48+
large_f_contig_tensor=torch.rand(1024, 4).t(),
49+
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
50+
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
4351
)
4452

4553
encoder = MsgpackEncoder()
4654
decoder = MsgpackDecoder(MyType)
4755

4856
encoded = encoder.encode(obj)
4957

50-
# There should be the main buffer + 2 large tensor buffers
51-
# + 1 large numpy array. "large" is <= 256 bytes.
58+
# There should be the main buffer + 4 large tensor buffers
59+
# + 1 large numpy array. "large" means >= 512 bytes.
5260
# The two small tensors are encoded inline.
53-
assert len(encoded) == 4
61+
assert len(encoded) == 6
5462

5563
decoded: MyType = decoder.decode(encoded)
5664

@@ -62,7 +70,7 @@ def test_encode_decode():
6270

6371
encoded2 = encoder.encode_into(obj, preallocated)
6472

65-
assert len(encoded2) == 4
73+
assert len(encoded2) == 6
6674
assert encoded2[0] is preallocated
6775

6876
decoded2: MyType = decoder.decode(encoded2)
@@ -78,3 +86,9 @@ def assert_equal(obj1: MyType, obj2: MyType):
7886
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
7987
assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
8088
assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
89+
assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
90+
assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
91+
assert torch.equal(obj1.small_non_contig_tensor,
92+
obj2.small_non_contig_tensor)
93+
assert torch.equal(obj1.large_non_contig_tensor,
94+
obj2.large_non_contig_tensor)

vllm/v1/serial_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
CUSTOM_TYPE_PICKLE = 1
1616
CUSTOM_TYPE_CLOUDPICKLE = 2
17+
CUSTOM_TYPE_RAW_VIEW = 3
1718

1819
# TODO calibrate this size
19-
INLINE_BUF_SIZE_THRESHOLD = 256
20+
MIN_NOCOPY_BUF_SIZE = 512
2021

2122
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
2223

@@ -76,14 +77,16 @@ def _encode_ndarray(
7677
self, obj: np.ndarray
7778
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
7879
assert self.aux_buffers is not None
79-
if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
80-
# Encode small arrays and scalars inline.
81-
data = obj.data
80+
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
81+
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
82+
# Encode small arrays and scalars inline. Using this extension type
83+
# ensures we can avoid copying when decoding.
84+
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
8285
else:
83-
# Otherwise encode index of backing buffer.
84-
obj = np.ascontiguousarray(obj)
86+
# Otherwise encode index of backing buffer to avoid copy.
8587
data = len(self.aux_buffers)
86-
self.aux_buffers.append(obj.data)
88+
self.aux_buffers.append(arr_data)
89+
8790
# We serialize the ndarray as a tuple of native types.
8891
# The data is either inlined if small, or an index into a list of
8992
# backing buffers that we've stashed in `aux_buffers`.
@@ -131,6 +134,8 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
131134
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
132135

133136
def ext_hook(self, code: int, data: memoryview) -> Any:
137+
if code == CUSTOM_TYPE_RAW_VIEW:
138+
return data
134139
if code == CUSTOM_TYPE_PICKLE:
135140
return pickle.loads(data)
136141
if code == CUSTOM_TYPE_CLOUDPICKLE:

0 commit comments

Comments
 (0)