Skip to content

Commit cfe4c15

Browse files
DarkLight1337lgeiger
authored andcommitted
[Bugfix] Fix incorrect original shape in hashing (vllm-project#23672)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
1 parent 1e84fbf commit cfe4c15

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

tests/multimodal/test_hasher.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ def test_hash_collision_image_transpose():
4545
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
4646

4747

48-
def test_hash_collision_tensor_shape():
48+
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
49+
def test_hash_collision_tensor_shape(dtype):
4950
# The hash should be different though the data is the same when flattened
50-
arr1 = torch.zeros((5, 10, 20, 3))
51-
arr2 = torch.zeros((10, 20, 5, 3))
51+
arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype)
52+
arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype)
5253

5354
hasher = MultiModalHasher
5455
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)

vllm/multimodal/hasher.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,22 @@ def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
4545
if isinstance(obj, torch.Tensor):
4646
tensor_obj: torch.Tensor = obj.cpu()
4747
tensor_dtype = tensor_obj.dtype
48+
tensor_shape = tensor_obj.shape
49+
50+
# NumPy does not support bfloat16.
51+
# Workaround: View the tensor as a contiguous 1D array of bytes
4852
if tensor_dtype == torch.bfloat16:
4953
tensor_obj = tensor_obj.contiguous()
5054
tensor_obj = tensor_obj.view(
5155
(tensor_obj.numel(), )).view(torch.uint8)
56+
5257
return cls.item_to_bytes(
5358
"tensor", {
5459
"original_dtype": str(tensor_dtype),
55-
"original_shape": tuple(tensor_obj.shape),
56-
"data": tensor_obj.numpy()
60+
"original_shape": tuple(tensor_shape),
61+
"data": tensor_obj.numpy(),
5762
})
63+
5864
return cls.item_to_bytes("tensor", tensor_obj.numpy())
5965
if isinstance(obj, np.ndarray):
6066
# If the array is non-contiguous, we need to copy it first

0 commit comments

Comments
 (0)