Skip to content

Commit 1d07ac8

Browse files
comaniac authored and weilong.yu committed
[Misc][V1] Fix type in v1 prefix caching (vllm-project#11151)
1 parent 4f0bdeb commit 1d07ac8

File tree

3 files changed

+27
-15
lines changed

3 files changed

+27
-15
lines changed

tests/v1/core/test_prefix_caching.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_prefill():
4949
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
5050
assert manager.block_pool[block_id].block_hash == block_hash
5151
assert manager.block_pool[block_id].ref_cnt == 1
52-
parent_block_hash = block_hash
52+
parent_block_hash = block_hash.hash_value
5353

5454
# Check partial/preallocated block metadata
5555
for block_id in (3, 4):
@@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
360360
assert not computed_blocks
361361
# Just ask for 1 block.
362362
blocks = manager.allocate_slots(req, block_size, computed_blocks)
363+
req.num_computed_tokens = block_size
363364
assert len(blocks) == 1 + num_preallocated_blocks
364365

365-
# Append slots to the block.
366-
req.num_computed_tokens = block_size * len(blocks) # Assume all used.
367-
blocks = manager.append_slots(req, block_size) # Append 1 block.
366+
# Assume all computed.
367+
manager.append_slots(req, block_size * (len(blocks) - 1))
368+
req.num_computed_tokens = block_size * len(blocks)
369+
370+
# Append 1 block.
371+
blocks = manager.append_slots(req, block_size)
368372
assert len(blocks) == 1 + num_preallocated_blocks
369373

370374

vllm/v1/core/kv_cache_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,8 @@ def _cache_full_blocks(
375375
prev_block: The previous block in the chain.
376376
"""
377377
# Update the new blocks with the block hashes through the chain.
378-
prev_block_hash = (prev_block.block_hash
379-
if prev_block is not None else None)
378+
prev_block_hash_value = (prev_block.block_hash.hash_value
379+
if prev_block is not None else None)
380380
for i, blk in enumerate(full_blocks):
381381
blk_idx = blk_start_idx + i
382382

@@ -390,10 +390,10 @@ def _cache_full_blocks(
390390
f"{request.request_id}({request})")
391391

392392
# Compute the hash of the current block.
393-
block_hash = hash_block_tokens(prev_block_hash,
393+
block_hash = hash_block_tokens(prev_block_hash_value,
394394
tuple(block_tokens))
395395

396396
# Update and add the full block to the cache.
397397
blk.block_hash = block_hash
398398
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
399-
prev_block_hash = block_hash
399+
prev_block_hash_value = block_hash.hash_value

vllm/v1/core/kv_cache_utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
"""KV-Cache Utilities."""
22
from dataclasses import dataclass
3-
from typing import List, Optional, Tuple
3+
from typing import List, NamedTuple, Optional, Tuple
44

55
from vllm.logger import init_logger
66

77
logger = init_logger(__name__)
88

9-
BlockHashType = Tuple[int, Tuple[int]]
9+
10+
class BlockHashType(NamedTuple):
11+
"""Hash value of a block and the token IDs in the block.
12+
The reason we keep a tuple of token IDs is to make sure no hash
13+
collision happens when the hash value is the same.
14+
"""
15+
hash_value: int
16+
token_ids: Tuple[int]
1017

1118

1219
@dataclass
@@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
171178
The hash value of the block and the token ids in the block.
172179
The entire tuple is used as the hash key of the block.
173180
"""
174-
return (hash(
175-
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
181+
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
182+
curr_block_token_ids)
176183

177184

178185
def hash_request_tokens(block_size: int,
@@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
188195
The list of computed hash values.
189196
"""
190197
ret = []
191-
parent_block_hash = None
198+
parent_block_hash_value = None
192199
for start in range(0, len(token_ids), block_size):
193200
end = start + block_size
194201
block_token_ids = tuple(token_ids[start:end])
195202
# Do not hash the block if it is not full.
196203
if len(block_token_ids) < block_size:
197204
break
198-
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
205+
block_hash = hash_block_tokens(parent_block_hash_value,
206+
block_token_ids)
199207
ret.append(block_hash)
200-
parent_block_hash = block_hash
208+
parent_block_hash_value = block_hash.hash_value
201209
return ret

0 commit comments

Comments
 (0)