diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index a03810625466..2d7411381e16 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -81,7 +81,9 @@ def test_prefill(hash_algo): assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [1, 2, 3, 4] # Check full block metadata @@ -108,7 +110,9 @@ def test_prefill(hash_algo): assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) + blocks = manager.allocate_slots(req1, num_new_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [5] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -140,7 +144,9 @@ def test_prefill(hash_algo): assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) + blocks = manager.allocate_slots(req2, num_new_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [6] # Although we only have 6 free blocks, we have 8 blocks in @@ -161,7 +167,9 @@ def test_prefill(hash_algo): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 10, + len(computed_blocks.blocks) * 16, + computed_blocks) # This block ID order also checks the eviction order. assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert manager.block_pool.free_block_queue.num_free_blocks == 0 @@ -197,7 +205,9 @@ def test_prefill_plp(): assert len(manager.req_to_block_hashes[req0.request_id]) == 0 assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [1, 2, 3, 4] req0_block_hashes = [b.block_hash for b in blocks.blocks] @@ -226,7 +236,9 @@ def test_prefill_plp(): assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) + blocks = manager.allocate_slots(req1, num_new_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [5] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -259,7 +271,9 @@ def test_prefill_plp(): assert len(manager.req_to_block_hashes[req2.request_id]) == 0 assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 55, computed_blocks) + blocks = manager.allocate_slots(req2, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks] == req0_block_hashes @@ -290,14 +304,18 @@ def test_decode(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 4) + new_blocks = manager.allocate_slots(req0, 4, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 0 assert manager.single_type_manager.req_to_blocks[ req0.request_id][-1].block_hash is None @@ -308,7 +326,9 @@ def test_decode(): # the preallocated block. for _ in range(9 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 19) + new_blocks = manager.allocate_slots(req0, 19, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 1 assert manager.single_type_manager.req_to_blocks[ req0.request_id][-2].block_hash is not None @@ -328,7 +348,9 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) + blocks = manager.allocate_slots(req0, 5 * 16 + 7, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 6 # 5 full + 1 partial # 3 blocks. @@ -337,7 +359,9 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) + blocks = manager.allocate_slots(req1, 3 * 16, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 3 # 3 full blocks last_token_id += 3 * 16 @@ -357,7 +381,9 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == [1, 2] assert num_computed_tokens == 2 * 16 - blocks = manager.allocate_slots(req2, 3, computed_blocks) + blocks = manager.allocate_slots(req2, 3, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [10] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -380,7 +406,9 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req, num_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 # Deallocate the block. @@ -392,7 +420,9 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) + blocks = manager.allocate_slots(req, num_tokens - 1, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 assert manager.block_pool.blocks[ @@ -417,7 +447,9 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req0, num_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 assert blocks.blocks[0].block_id == 1 @@ -426,7 +458,9 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) + blocks = manager.allocate_slots(req1, num_tokens, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 1 assert blocks.blocks[0].block_id == 2 @@ -443,6 +477,7 @@ def test_computed_blocks_not_evicted(): assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, + len(computed_blocks.blocks) * 16, computed_blocks) assert len(blocks.blocks) == 1 assert blocks.blocks[0].block_id == 2 @@ -464,7 +499,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 10, computed_blocks) + blocks = manager.allocate_slots(req1, 10, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 3 # Free the blocks. @@ -475,7 +512,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 16, computed_blocks) + blocks = manager.allocate_slots(req2, 16, + len(computed_blocks.blocks) * 16, + computed_blocks) assert len(blocks.blocks) == 4 # New requests should not have any blocks. @@ -483,7 +522,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 4, computed_blocks) + blocks = manager.allocate_slots(req3, 4, + len(computed_blocks.blocks) * 16, + computed_blocks) assert not blocks @@ -581,14 +622,18 @@ def test_mm_prefix_caching(): assert block_hashes[1].extra_keys == ("aaa", "bbb") assert block_hashes[2].extra_keys == ("bbb", ) - blocks = manager.allocate_slots(req0, 59, computed_blocks) + blocks = manager.allocate_slots(req0, 59, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5) + new_blocks = manager.allocate_slots(req0, 5, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 0 # The just completed block should have hashes with extra keys. @@ -638,14 +683,18 @@ def test_cache_key_salting(): assert block_hashes[1].extra_keys is None assert block_hashes[2].extra_keys is None - blocks = manager.allocate_slots(req0, 59, computed_blocks) + blocks = manager.allocate_slots(req0, 59, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5) + new_blocks = manager.allocate_slots(req0, 5, + len(computed_blocks.blocks) * 16, + computed_blocks) assert new_blocks is not None and len(new_blocks.blocks) == 0 # Now one more block that should not have extra keys. @@ -691,7 +740,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks assert num_computed_tokens == 0 - manager.allocate_slots(req0, 48, computed_blocks) + manager.allocate_slots(req0, 48, + len(computed_blocks.blocks) * 16, computed_blocks) block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | @@ -699,7 +749,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks == block_part0 assert num_computed_tokens == 3 * 16 - manager.allocate_slots(req1, 48, computed_blocks) + manager.allocate_slots(req1, 48, + len(computed_blocks.blocks) * 16, computed_blocks) block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | @@ -713,7 +764,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks assert num_computed_tokens == 0 - manager.allocate_slots(req2, block_size * 2, computed_blocks) + manager.allocate_slots(req2, block_size * 2, + len(computed_blocks.blocks) * 16, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -724,7 +776,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert computed_blocks.blocks == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. - assert manager.allocate_slots(req3, 48, computed_blocks) is None + assert manager.allocate_slots(req3, 48, + len(computed_blocks.blocks) * 16, + computed_blocks) is None # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} # Block 3-5 are free. @@ -751,7 +805,9 @@ def test_reset_prefix_cache(): computed_blocks, _ = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks.blocks) == 3 - blocks = manager.allocate_slots(req1, 7, computed_blocks) + blocks = manager.allocate_slots(req1, 7, + len(computed_blocks.blocks) * 16, + computed_blocks) assert blocks.get_block_ids() == [5] # Failed to reset prefix cache because some blocks are not freed yet. @@ -782,7 +838,8 @@ def test_prefix_cache_stats_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks assert num_computed_tokens == 0 - manager.allocate_slots(req, 16, computed_blocks) + manager.allocate_slots(req, 16, + len(computed_blocks.blocks) * 16, computed_blocks) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None @@ -860,7 +917,8 @@ def test_eagle_enabled_removes_last_block(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.allocate_slots(req, len(token_ids), + len(computed_blocks.blocks) * 16, computed_blocks) manager.free(req) # New request with same tokens + Eagle enabled @@ -889,7 +947,8 @@ def test_eagle_with_partial_blocks(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.allocate_slots(req, len(token_ids), + len(computed_blocks.blocks) * 16, computed_blocks) manager.free(req) # New request with Eagle enabled @@ -928,7 +987,8 @@ def test_eagle_with_sliding_window(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.allocate_slots(req, len(token_ids), + len(computed_blocks.blocks) * 16, computed_blocks) # record the block hash of the first block in the request for later use block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] assert block_hash_first_block is not None diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b34b53155cc3..79a45f752d5e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -121,13 +121,6 @@ def get_computed_blocks(self, - A list of blocks that are computed for the request. - The number of computed tokens. """ - - # Request already has blocks from async load via KVConnector. - num_existing_blocks = len( - self.single_type_manager.req_to_blocks[request.request_id]) - if num_existing_blocks > 0: - return KVCacheBlocks.create_empty(), request.num_computed_tokens - # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. if (not self.enable_caching @@ -183,6 +176,7 @@ def allocate_slots( self, request: Request, num_new_tokens: int, + num_new_computed_tokens: int = 0, new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, @@ -194,8 +188,10 @@ def allocate_slots( num_new_tokens: The number of tokens to allocate, including external tokens. Note that this does not include tokens that have already been computed locally (i.e. new_computed_blocks). - new_computed_blocks: The new computed blocks just hitting the - prefix caching. + num_new_computed_tokens: The number of new computed tokens just + hitting the prefix caching, excluding external tokens. + new_computed_blocks: The cached blocks for the above new computed + tokens. num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. @@ -240,7 +236,7 @@ def allocate_slots( # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - len(new_computed_block_list) * self.block_size) + num_new_computed_tokens) num_tokens_need_slot = min( num_computed_tokens + num_new_tokens + num_lookahead_tokens, self.max_model_len) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7773853b096a..9f051b73c263 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -18,7 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -311,12 +311,14 @@ def schedule(self) -> SchedulerOutput: break request = self.waiting[0] - + num_prealloc_computed_tokens = 0 # P/D: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: is_ready = self._update_waiting_for_remote_kv(request) if is_ready: request.status = RequestStatus.WAITING + num_prealloc_computed_tokens = ( + request.num_computed_tokens) else: self.waiting.popleft() skipped_waiting_requests.appendleft(request) @@ -345,18 +347,25 @@ def schedule(self) -> SchedulerOutput: continue # Get already-cached tokens. - new_computed_blocks, num_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + if num_prealloc_computed_tokens == 0: + new_computed_blocks, num_native_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + else: + # P/D: skip checking prefix cache if loaded from remote kvs. + new_computed_blocks = KVCacheBlocks.create_empty() + num_native_computed_tokens = 0 # Get externally-cached tokens if using a KVConnector. - num_external_tokens, load_kv_async = ( + num_external_computed_tokens, load_kv_async = ( (0, False) if self.connector is None else self.connector.get_num_new_matched_tokens( - request, num_computed_tokens)) + request, num_native_computed_tokens)) # Total computed tokens (local + external). - num_computed_tokens += num_external_tokens + num_computed_tokens = (num_native_computed_tokens + + num_external_computed_tokens + + num_prealloc_computed_tokens) encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget @@ -390,7 +399,8 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, - num_new_tokens + num_external_tokens, + num_new_tokens + num_external_computed_tokens, + num_native_computed_tokens, new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async, @@ -406,7 +416,7 @@ def schedule(self) -> SchedulerOutput: self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks, - num_external_tokens, + num_external_computed_tokens, ) self.waiting.popleft()