Skip to content

Commit

Permalink
[v1][Bugfix] Add extra_keys to block_hash for prefix caching (vllm-project#12603)
Browse files Browse the repository at this point in the history

This PR adds an extra key to the block hash, so that two blocks with the
same token string but different extra_keys in their parent blocks produce
different hash values. For example, it generates different hash values
for the second block of the following two requests:
```python
request1 = make_request(
        request_id=0,
        prompt_token_ids=[_ for _ in range(6)],
        mm_positions=[{
            "offset": 0,
            "length": 3
        }, {
            "offset": 3,
            "length": 3
        }],
        mm_hashes=["hash1", "hash2"],
    )
    request2 = make_request(
        request_id=1,
        prompt_token_ids=[_ for _ in range(6)],
        mm_positions=[{
            "offset": 0,
            "length": 3
        }, {
            "offset": 3,
            "length": 3
        }],
        mm_hashes=["hash3", "hash2"],
    )
```

---------

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
  • Loading branch information
heheda12345 authored and Isotr0py committed Feb 2, 2025
1 parent c2f4e2b commit 0e76791
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
34 changes: 33 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_hash_block_tokens():
extra_keys)
assert isinstance(block_hash, BlockHashType)
assert block_hash.hash_value == hash(
(parent_block_hash, *curr_block_token_ids))
(parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.token_ids == curr_block_token_ids
assert block_hash.extra_keys == extra_keys

Expand Down Expand Up @@ -227,6 +227,38 @@ def test_hash_request_tokens():
assert block_hashes[1].extra_keys == ("hash2", )


def test_hash_tokens_different_mm_input():
    """Blocks with identical token ids but different multi-modal hashes in
    their parent blocks must hash differently (extra_keys take part in the
    block hash)."""

    def _build_request(req_id, mm_hashes):
        # Build fresh position dicts per request so the two requests never
        # share mutable state.
        return make_request(
            request_id=req_id,
            prompt_token_ids=list(range(6)),
            mm_positions=[{
                "offset": 0,
                "length": 3
            }, {
                "offset": 3,
                "length": 3
            }],
            mm_hashes=mm_hashes,
        )

    block_size = 3
    hashes_a = hash_request_tokens(block_size,
                                   _build_request(0, ["hash1", "hash2"]))
    hashes_b = hash_request_tokens(block_size,
                                   _build_request(1, ["hash3", "hash2"]))
    # First blocks differ because their own mm hashes differ; second blocks
    # differ only through the differing parent-block hash.
    assert hashes_a[0] != hashes_b[0]
    assert hashes_a[1] != hashes_b[1]


def test_hash_request_tokens_no_mm_inputs():
request = make_request(
request_id=0,
Expand Down
6 changes: 4 additions & 2 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,10 @@ def hash_block_tokens(
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
tuple(curr_block_token_ids), extra_keys)
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType(
hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys)


def hash_request_tokens(block_size: int,
Expand Down

0 comments on commit 0e76791

Please sign in to comment.