Commit 393c8f3

[Core] Use tuple for kv cache group block ids

IMO tuple is a better fit than list here, since there is always a fixed number of elements corresponding to the number of kv cache groups. It is also more efficient than a list. As part of this refactoring I've included a few other adjacent code simplifications.

Signed-off-by: Nick Hill <nhill@redhat.com>

1 parent 441b65d
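
The core of the change is narrowing the type of per-request block IDs from list[list[int]] to tuple[list[int], ...]: one list of block IDs per KV cache group, with the number of groups fixed for a given model. A minimal sketch of that shape, where GroupedBlockIds is an illustrative stand-in and not vLLM's actual class:

```python
from dataclasses import dataclass


@dataclass
class GroupedBlockIds:
    # One list of block IDs per KV cache group. The number of groups is
    # fixed for a given model, so an immutable tuple is a natural fit and
    # can't accidentally grow or shrink the way a list can.
    block_ids: tuple[list[int], ...]


# Single-group model: the trailing comma makes a one-element tuple.
single = GroupedBlockIds(block_ids=([1, 2, 3, 4], ))

# Hybrid model with three KV cache groups.
hybrid = GroupedBlockIds(
    block_ids=([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]))

assert single.block_ids == ([1, 2, 3, 4], )
assert len(hybrid.block_ids) == 3
```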

File tree: 12 files changed, +125 −142 lines changed

tests/v1/core/test_prefix_caching.py (26 additions, 26 deletions)

```diff
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 
     # Check full block metadata
     parent_block_hash = None
@@ -141,14 +141,14 @@ def test_prefill(hash_algo):
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
-    for block in computed_blocks.blocks[0]:
+    assert blocks.get_block_ids() == ([5], )
+    for (block, ) in computed_blocks.blocks:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[6]]
+    assert blocks.get_block_ids() == ([6], )
 
     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     # This block ID order also checks the eviction order.
-    assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
+    assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
-                                      [9, 10, 11, 12]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
+                                                     8], [9, 10, 11, 12])
 
     # Check full block metadata
     parent_block_hash = None
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
-                                               [0, 10, 11]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
+                                                           7], [0, 10, 11])
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[13], [14], [15]]
+    assert blocks.get_block_ids() == ([13], [14], [15])
     for block_per_group in computed_blocks.blocks:
         for block in block_per_group:
             if block != manager.block_pool.null_block:
@@ -374,8 +374,8 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
-    req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
+    req0_block_hashes = [b.block_hash for b in blocks.blocks]
 
     # Check full block metadata
     parent_block_hash = None
@@ -400,14 +400,14 @@ def test_prefill_plp():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
-    for block in computed_blocks.blocks[0]:
+    assert blocks.get_block_ids() == ([5], )
+    for block in computed_blocks.blocks:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -443,8 +443,8 @@ def test_prefill_plp():
                                     computed_blocks)
     block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
-    assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
-    assert block_ids != [[1, 2, 3, 4]]
+    assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
+    assert block_ids != ([1, 2, 3, 4], )
 
     # Request #2 block hashes are valid since request #0 hashes are.
     # Check block reference counts.
@@ -474,7 +474,7 @@ def test_decode():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 
     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
@@ -546,12 +546,12 @@ def test_evict():
     # Touch the first 2 blocks.
    req2 = make_request("2", list(range(2 * 16 + 3)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert computed_blocks.get_block_ids() == [[1, 2]]
+    assert computed_blocks.get_block_ids() == ([1, 2], )
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[10]]
+    assert blocks.get_block_ids() == ([10], )
     assert manager.block_pool.free_block_queue.num_free_blocks == 7
 
 
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
@@ -926,7 +926,7 @@ def test_cache_key_salting():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     blocks = manager.allocate_slots(req0, 55)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 
     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
     blocks = manager.allocate_slots(req1, 7,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == ([5], )
 
     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
```
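
A note on the updated assertions above: `([1, 2, 3, 4], )` is a one-element tuple containing a list (the trailing comma is what makes it a tuple), and a tuple never compares equal to a list, which is why the `block_ids != ([1, 2, 3, 4], )` check had to change shape as well. A quick self-contained illustration:

```python
block_ids = ([1, 2, 3, 4], )

# The trailing comma matters: without it this is just a parenthesized list.
assert isinstance(([1, 2, 3, 4]), list)
assert isinstance(block_ids, tuple) and len(block_ids) == 1

# Tuples and lists are never equal, even with identical contents...
assert block_ids != [[1, 2, 3, 4]]
# ...but tuple comparison is element-wise, so the inner lists are still
# compared by value.
assert block_ids == ([1, 2, 3, 4], )
```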

tests/v1/tpu/worker/test_tpu_model_runner.py (4 additions, 4 deletions)

```diff
@@ -87,7 +87,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
                 mm_hashes=[],
                 mm_positions=[],
                 sampling_params=SamplingParams(),
-                block_ids=[[0]],  # block_ids should be list[list[int]]
+                block_ids=([0], ),  # block_ids should be tuple[list[int]]
                 num_computed_tokens=0,
                 lora_request=None,
             ))
@@ -132,10 +132,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     # This is safe since we currently only use single KV cache groups
     block_table = multi_group_block_table[0]
 
-    # req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
+    # req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
     # Extract the first group's block IDs
     if isinstance(req_state.block_ids[0], list):
-        # New format: list[list[int]] - extract first group
+        # New format: tuple[list[int], ...] - extract first group
         req_block_ids = req_state.block_ids[0]
     else:
         # Legacy format: list[int] - use directly
@@ -226,7 +226,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[[]],
+        new_block_ids=([], ),
         num_computed_tokens=0,
     )
 
```
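The `_is_req_state_block_table_match` hunk keeps a runtime check that tolerates both shapes. A standalone sketch of that detection logic (the helper name here is hypothetical, for illustration only):

```python
def first_group_block_ids(block_ids) -> list[int]:
    """Return the first KV cache group's block IDs, accepting either the
    new tuple[list[int], ...] shape or the legacy flat list[int] shape."""
    if block_ids and isinstance(block_ids[0], list):
        # New format: one list per KV cache group - take the first group.
        return block_ids[0]
    # Legacy format: a flat list of block IDs for a single group.
    return list(block_ids)


assert first_group_block_ids(([0, 1], [2, 3])) == [0, 1]  # new format
assert first_group_block_ids([0, 1]) == [0, 1]            # legacy format
```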

tests/v1/worker/test_gpu_input_batch.py (1 addition, 1 deletion)

```diff
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],
-        block_ids=[[]],
+        block_ids=([], ),
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,
```

tests/v1/worker/test_gpu_model_runner.py (2 additions, 2 deletions)

```diff
@@ -122,7 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
                 mm_hashes=[],
                 mm_positions=[],
                 sampling_params=SamplingParams(),
-                block_ids=[[0]],
+                block_ids=([0], ),
                 num_computed_tokens=0,
                 lora_request=None,
             ))
@@ -250,7 +250,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[[]],
+        new_block_ids=([], ),
         num_computed_tokens=0,
     )
 
```

vllm/v1/core/block_pool.py (4 additions, 4 deletions)

```diff
@@ -89,8 +89,8 @@ def get_cached_block(
                 BlockHashWithGroupId(block_hash, group_id))
             if not cached_blocks_one_group:
                 return None
-            first_block_id = next(iter(cached_blocks_one_group))
-            cached_blocks.append(cached_blocks_one_group[first_block_id])
+            first_block = next(iter(cached_blocks_one_group.values()))
+            cached_blocks.append(first_block)
         return cached_blocks
 
     def cache_full_blocks(
@@ -260,7 +260,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
             return True
         return False
 
-    def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
+    def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
         """Touch a block increases its reference count by 1, and may remove
         the block from the free queue. This is used when a block is hit by
         another request with the same prefix.
@@ -299,7 +299,7 @@ def reset_prefix_cache(self) -> bool:
             bool: True if the prefix cache is successfully reset,
             False otherwise.
         """
-        num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
+        num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
         if num_used_blocks != 1:  # The null block is always marked as used
             logger.warning(
                 "Failed to reset prefix cache because some "
```
