@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import importlib
 from collections.abc import Callable
+from typing import Any
 
 import pytest
 import torch
@@ -32,6 +33,7 @@
     init_none_hash,
     is_kv_cache_spec_uniform,
     make_block_hash_with_group_id,
+    tensor_data,
 )
 from vllm.v1.kv_cache_interface import (
     FullAttentionSpec,
@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):
 
 def make_request(
     request_id: str,
-    prompt_token_ids: list[int],
+    prompt_token_ids: list[int] | None,
     block_size: int = 3,
     hash_fn: Callable = hash,
     mm_positions: list[PlaceholderRange] | None = None,
     mm_hashes: list[str] | None = None,
     cache_salt: str | None = None,
+    prompt_embeds: torch.Tensor | None = None,
 ):
     mm_features = []
     if mm_positions is not None:
@@ -90,6 +93,7 @@ def make_request(
         lora_request=None,
         cache_salt=cache_salt,
         block_hasher=get_request_block_hasher(block_size, hash_fn),
+        prompt_embeds=prompt_embeds,
     )
 
 
@@ -450,6 +454,56 @@ def test_generate_block_hash_extra_keys_cache_salt():
     assert next_mm_idx == 1
 
 
+def test_generate_block_hash_extra_keys_prompt_embeds():
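+    # With no prompt token ids, each block's extra keys should reduce to the
+    # raw bytes of that block's slice of the prompt embeddings.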
+    prompt_embeds = torch.randn(10, 3)
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds,
+    )
+
+    # Test with prompt embeds for the first block
+    extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
+    expected_embeds = prompt_embeds[0:5]
+    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
+    assert extra_keys == (expected_bytes,)
+
+    # Test with prompt embeds for the second block
+    extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
+    expected_embeds = prompt_embeds[5:10]
+    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
+    assert extra_keys == (expected_bytes,)
+
+
+def test_generate_block_hash_extra_keys_different_prompt_embeds():
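+    # Distinct embeddings must yield distinct extra keys; otherwise two
+    # different embedded prompts could collide on the same block hash.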
+    prompt_embeds1 = torch.randn(10, 3)
+    prompt_embeds2 = torch.randn(10, 3)
+    request1 = make_request(
+        request_id="0",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds1,
+    )
+    request2 = make_request(
+        request_id="1",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds2,
+    )
+
+    extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0)
+    extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0)
+    assert extra_keys1 != extra_keys2
+
+
 def test_generate_block_hash_extra_keys_lora():
     request = make_request(
         request_id="0",
@@ -1556,3 +1606,92 @@ def test_merge_mla_spec():
     ]
     with pytest.raises(AssertionError):
         kv_cache_specs[0].merge(kv_cache_specs)
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
+    block_size = 3
+    num_tokens = 2 * block_size
+    prompt_token_ids = list(range(num_tokens))
+    hidden_size = 5
+    prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=prompt_token_ids,
+        block_size=block_size,
+        hash_fn=hash_fn,
+        prompt_embeds=prompt_embeds,
+    )
+
+    block_hashes = request.block_hashes
+    assert len(block_hashes) == 2
+
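+    # Each block hash chains (parent hash, block token ids, extra keys); the
+    # extra key here is the byte view of that block's embedding slice.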
+    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    expected_hash1 = hash_fn(
+        (
+            kv_cache_utils.NONE_HASH,
+            tuple(prompt_token_ids[:block_size]),
+            (block1_embeds_bytes,),
+        )
+    )
+    assert block_hashes[0] == expected_hash1
+
+    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    expected_hash2 = hash_fn(
+        (
+            block_hashes[0],
+            tuple(prompt_token_ids[block_size:num_tokens]),
+            (block2_embeds_bytes,),
+        )
+    )
+    assert block_hashes[1] == expected_hash2
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
+    block_size = 3
+    num_tokens = 2 * block_size
+    prompt_token_ids = list(range(num_tokens))
+    hidden_size = 5
+    prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=prompt_token_ids,
+        block_size=block_size,
+        hash_fn=hash_fn,
+        mm_positions=[
+            PlaceholderRange(offset=0, length=3),
+            PlaceholderRange(offset=3, length=3),
+        ],
+        mm_hashes=["hash1", "hash2"],
+        prompt_embeds=prompt_embeds,
+    )
+
+    block_hashes = request.block_hashes
+    assert len(block_hashes) == 2
+
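+    # With multimodal inputs, each block's extra keys carry the mm hash first,
+    # followed by the bytes of the block's embedding slice.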
+    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    expected_hash1 = hash_fn(
+        (
+            kv_cache_utils.NONE_HASH,
+            tuple(prompt_token_ids[:block_size]),
+            ("hash1", block1_embeds_bytes),
+        )
+    )
+    assert block_hashes[0] == expected_hash1
+
+    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    expected_hash2 = hash_fn(
+        (
+            block_hashes[0],
+            tuple(prompt_token_ids[block_size:num_tokens]),
+            ("hash2", block2_embeds_bytes),
+        )
+    )
+    assert block_hashes[1] == expected_hash2