Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/multimodal/test_hasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ def test_hash_collision_image_transpose():
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_hash_collision_tensor_shape(dtype):
    """The hash must be shape-sensitive: tensors whose flattened data is
    identical but whose shapes differ must produce different hashes.

    Parametrized over float32 and bfloat16 to also exercise the
    bfloat16 serialization path (NumPy has no native bfloat16 dtype).
    """
    # Both tensors are all-zeros with the same element count; only the
    # shape differs, so a shape-insensitive hash would collide here.
    arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype)
    arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype)

    hasher = MultiModalHasher
    assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
Expand Down
10 changes: 8 additions & 2 deletions vllm/multimodal/hasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,22 @@ def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
if isinstance(obj, torch.Tensor):
tensor_obj: torch.Tensor = obj.cpu()
tensor_dtype = tensor_obj.dtype
tensor_shape = tensor_obj.shape

# NumPy does not support bfloat16.
# Workaround: View the tensor as a contiguous 1D array of bytes
if tensor_dtype == torch.bfloat16:
tensor_obj = tensor_obj.contiguous()
tensor_obj = tensor_obj.view(
(tensor_obj.numel(), )).view(torch.uint8)

return cls.item_to_bytes(
"tensor", {
"original_dtype": str(tensor_dtype),
"original_shape": tuple(tensor_obj.shape),
"data": tensor_obj.numpy()
"original_shape": tuple(tensor_shape),
"data": tensor_obj.numpy(),
})
Comment on lines 53 to 62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

On closer inspection, the approach in this change is valid: `torch.Tensor.view()` is overloaded to accept either a shape or a `torch.dtype`, and a contiguous bfloat16 tensor (2 bytes per element) can be viewed as `uint8` — the size of the last dimension is scaled by the element-size ratio (×2 here). Because the code calls `.contiguous()` and flattens to 1D before the dtype view, the contiguity requirement of `view(dtype)` is satisfied and no `TypeError` occurs at runtime.

An alternative, if a NumPy-side conversion is preferred, is to view the bfloat16 tensor as a same-width integer type (e.g. `uint16`), convert it to a NumPy array, and view that array as `uint8` to obtain the raw bytes for hashing — though note that `torch.uint16` has limited support in older PyTorch releases, so the `uint8` view used here is the more portable option.

Suggested change
tensor_obj = tensor_obj.contiguous()
tensor_obj = tensor_obj.view(
(tensor_obj.numel(), )).view(torch.uint8)
return cls.item_to_bytes(
"tensor", {
"original_dtype": str(tensor_dtype),
"original_shape": tuple(tensor_obj.shape),
"data": tensor_obj.numpy()
"original_shape": tuple(tensor_shape),
"data": tensor_obj.numpy(),
})
tensor_obj = tensor_obj.contiguous()
# To correctly get the byte representation of a bfloat16 tensor,
# it should be viewed as a type of the same size (e.g., uint16),
# then converted to a numpy array, and finally viewed as uint8.
data_np = tensor_obj.view(torch.uint16).numpy().view(np.uint8)
return cls.item_to_bytes(
"tensor", {
"original_dtype": str(tensor_dtype),
"original_shape": tuple(tensor_shape),
"data": data_np,
})


return cls.item_to_bytes("tensor", tensor_obj.numpy())
if isinstance(obj, np.ndarray):
# If the array is non-contiguous, we need to copy it first
Expand Down