From c7c1bbcaf2147940f3aed0d0cfe203ae5a8d01a0 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 24 Apr 2025 14:24:51 -0700 Subject: [PATCH 1/4] version Signed-off-by: LiuXiaoxuanPKU --- tests/v1/core/test_prefix_caching.py | 137 +++++++++++++++++++++------ tests/v1/e2e/test_spec_decode.py | 48 +++++++++- vllm/v1/core/kv_cache_manager.py | 17 +++- vllm/v1/core/sched/scheduler.py | 3 +- 4 files changed, 175 insertions(+), 30 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 1b238d47c03a..a66fd52583ad 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -74,7 +74,8 @@ def test_prefill(hash_algo): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req0, False) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 @@ -100,7 +101,8 @@ def test_prefill(hash_algo): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req1, False) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 @@ -132,7 +134,8 @@ def test_prefill(hash_algo): # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req2, False) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 @@ -155,7 +158,8 @@ def test_prefill(hash_algo): # Cache miss and eviction. 
req3 = make_request("3", [99] * (16 * 10)) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req3, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) @@ -190,7 +194,8 @@ def test_prefill_plp(): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, prompt_logprobs=5) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req0, False) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 @@ -218,7 +223,8 @@ def test_prefill_plp(): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req1, False) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 @@ -252,7 +258,8 @@ def test_prefill_plp(): req2 = make_request("2", common_token_ids + unique_token_ids, prompt_logprobs=5) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req2, False) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 @@ -284,7 +291,8 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req0, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) @@ -319,7 +327,8 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req0, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) @@ -328,7 +337,8 @@ def test_evict(): # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req1, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) @@ -348,7 +358,8 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req2, False) assert [b.block_id for b in computed_blocks] == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) @@ -371,7 +382,8 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. 
num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, computed_blocks) @@ -383,7 +395,8 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. req = make_request("1", list(range(num_tokens - 1))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) @@ -407,7 +420,8 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req0, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) @@ -416,7 +430,8 @@ def test_computed_blocks_not_evicted(): # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req1, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) @@ -430,7 +445,8 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req2, False) assert len(computed_blocks) == 1 assert computed_blocks[0].block_id == 1 assert num_computed_tokens == block_size @@ -454,7 +470,8 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req1, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, computed_blocks) @@ -465,7 +482,8 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req2, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, computed_blocks) @@ -473,7 +491,8 @@ def test_basic_prefix_caching_disabled(): # New requests should not have any blocks. 
req3 = make_request("3", list(range(4))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req3, False) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, computed_blocks) @@ -563,7 +582,8 @@ def test_mm_prefix_caching(): all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req0, False) # Completed block should have hashes with extra keys. assert not computed_blocks @@ -599,7 +619,8 @@ def test_mm_prefix_caching(): all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req1, False) assert len(computed_blocks) == 3 assert num_computed_tokens == 3 * 16 @@ -621,7 +642,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req0, False) assert not computed_blocks assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, computed_blocks) @@ -629,7 +651,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req1, False) assert computed_blocks == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, computed_blocks) @@ -643,7 +666,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req2, False) assert not computed_blocks assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, computed_blocks) @@ -653,7 +677,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req3, False) assert computed_blocks == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. 
@@ -681,7 +706,7 @@ def test_reset_prefix_cache(): unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) - computed_blocks, _ = manager.get_computed_blocks(req1) + computed_blocks, _ = manager.get_computed_blocks(req1, False) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) @@ -712,7 +737,8 @@ def test_prefix_cache_stats_disabled(): # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks( + req, False) assert not computed_blocks assert num_computed_tokens == 0 manager.allocate_slots(req, 16, computed_blocks) @@ -720,3 +746,60 @@ def test_prefix_cache_stats_disabled(): # Ensure prefix_cache_stats remains None assert manager.prefix_cache_stats is None + + +def test_eagle_enabled_removes_last_block(): + """Verify Eagle does NOT remove blocks when request + length is divisible by block size.""" + block_size = 16 + manager = KVCacheManager( + make_kv_cache_config(block_size, num_blocks=10), + max_model_len=8192, + enable_caching=True, + ) + + # Request with 3 full blocks (48 tokens) + token_ids = [0] * (3 * block_size) + req = make_request("divisible_request", token_ids) + + # Prime the cache + computed_blocks, _ = manager.get_computed_blocks(req, enable_eagle=False) + manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.free(req) + + # New request with same tokens + Eagle enabled + req_eagle = make_request("eagle_divisible", token_ids) + computed_blocks, num_tokens = manager.get_computed_blocks( + req_eagle, enable_eagle=True) + + # Should retain 2 blocks: + # 1. Original 3 blocks → pop last hash → 2 matched blocks + # 2. last_block_hash is not None → Eagle pop is SKIPPED + assert len(computed_blocks) == 2 + assert num_tokens == 2 * block_size # 32 tokens + + +def test_eagle_with_partial_blocks(): + """Test Eagle behavior with requests containing partial blocks.""" + block_size = 16 + manager = KVCacheManager( + make_kv_cache_config(block_size, num_blocks=10), + max_model_len=8192, + enable_caching=True, + ) + # 2 full blocks + 5 tokens (non-divisible length) + token_ids = [0] * (2 * block_size + 5) + req = make_request("partial_block_test", token_ids) + + # Prime the cache + computed_blocks, _ = manager.get_computed_blocks(req, enable_eagle=False) + manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.free(req) + + # New request with Eagle enabled + req_eagle = make_request("partial_eagle", token_ids) + computed_blocks, num_tokens = manager.get_computed_blocks( + req_eagle, enable_eagle=True) + # Original match: 2 full blocks → Eagle removes 1 → 1 remaining + assert len(computed_blocks) == 1 + assert num_tokens == 1 * block_size diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 673714980592..377463e2e9c5 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -44,7 +44,6 @@ def test_prompts(): @pytest.fixture def sampling_config(): - # Only support greedy for now return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) @@ -144,3 +143,50 @@ def test_eagle_correctness( # Upon failure, inspect the outputs to check for inaccuracy. 
assert matches > int(0.7 * len(ref_outputs)) del spec_llm + + +def test_prefix_cache( + monkeypatch: pytest.MonkeyPatch, + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, + eagle_model_name: str, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + # Populate the cache with a few prompts + spec_llm = LLM(model=model_name, + speculative_config={ + "method": "eagle", + "model": eagle_model_name, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + disable_log_stats=False) + + first_outputs = spec_llm.chat(test_prompts, sampling_config) + + # Now, run the same prompts again to check if the cache is used + second_outputs = spec_llm.chat(test_prompts, sampling_config) + + # Check: + # 1. The output is the almost same as the first run. + # 2. There is cache hit. + matches = 0 + misses = 0 + for first_output, second_output in zip(first_outputs, second_outputs): + if first_output.outputs[0].text == second_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {first_output.outputs[0].text}") + print(f"spec_output: {second_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.9 * len(first_outputs)) + + print(spec_llm.llm_engine.engine_core.core_engine) + + del spec_llm diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 354300d3c2fe..373891754613 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -88,12 +88,17 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: return stats def get_computed_blocks( - self, request: Request) -> tuple[list[KVCacheBlock], int]: + self, request: Request, + enable_eagle: bool) -> tuple[list[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. Args: request: The request to get the computed blocks. + enable_eagle: Whether to enable eagle spec decode. If True, + we will drop the last matched block so that we can recompute + the last block to get the required hidden states for eagle + drafting head. Returns: A tuple containing: @@ -134,6 +139,16 @@ def get_computed_blocks( computed_blocks = ( self.specialized_manager.find_longest_cache_hit(block_hashes)) + + if enable_eagle and len( + computed_blocks) > 0 and last_block_hash is None: + # Drop the last matched block if (1) eagle is enabled and + # (2) there is a cache hit and (3) the last block hash is + # not removed. + # This is to recompute the last block to get the required + # hidden states for eagle drafting head. + computed_blocks.pop() + if self.log_stats: assert self.prefix_cache_stats is not None self.prefix_cache_stats.queries += len(block_hashes) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5adcdde5bcd7..dc248c03bd54 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -321,7 +321,8 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. computed_blocks, num_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks(request) + self.kv_cache_manager.get_computed_blocks( + request, self.num_lookahead_tokens > 0) # Get externally-cached tokens if using a KVConnector. 
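
# Illustrative aside (not part of the patch above): a self-contained sketch of
# what the new eagle branch in get_computed_blocks() does to a prefix-cache
# hit. It assumes the pre-existing behaviour that a block-aligned prompt has
# its last block hash popped before matching (so `last_block_hash` is not None
# in that case). Names below are local stand-ins, not vLLM APIs.
BLOCK_SIZE = 16

def matched_prefix(num_prompt_tokens: int, num_cached_full_blocks: int,
                   enable_eagle: bool) -> tuple[int, int]:
    """Return (num_matched_blocks, num_computed_tokens) for a new request."""
    matched = num_cached_full_blocks
    block_aligned = num_prompt_tokens % BLOCK_SIZE == 0
    if block_aligned and matched > 0:
        # Existing rule: when the whole prompt is cached, the last block is
        # recomputed so the request still has at least one token to run.
        matched -= 1
    if enable_eagle and matched > 0 and not block_aligned:
        # New rule in this patch: also recompute the last matched block so its
        # hidden states are available to the eagle drafting head.
        matched -= 1
    return matched, matched * BLOCK_SIZE

# These mirror the two unit tests added above (first-patch semantics):
assert matched_prefix(48, 3, enable_eagle=True) == (2, 32)  # divisible prompt
assert matched_prefix(37, 2, enable_eagle=True) == (1, 16)  # partial last block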
            num_external_tokens = (

From a77256feb6a032181cc4a2999d0cac424e4eba97 Mon Sep 17 00:00:00 2001
From: LiuXiaoxuanPKU
Date: Fri, 25 Apr 2025 13:59:17 -0700
Subject: [PATCH 2/4] fix comments

Signed-off-by: LiuXiaoxuanPKU
---
 tests/v1/core/test_prefix_caching.py | 18 +++++------
 tests/v1/e2e/test_spec_decode.py     | 47 ----------------------------
 vllm/v1/core/kv_cache_manager.py     | 13 +++-----
 vllm/v1/core/sched/scheduler.py      |  4 ++-
 4 files changed, 17 insertions(+), 65 deletions(-)

diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index a66fd52583ad..58019797c4fd 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -763,20 +763,20 @@ def test_eagle_enabled_removes_last_block():
     req = make_request("divisible_request", token_ids)
 
     # Prime the cache
-    computed_blocks, _ = manager.get_computed_blocks(req, enable_eagle=False)
+    computed_blocks, _ = manager.get_computed_blocks(req, use_eagle=False)
     manager.allocate_slots(req, len(token_ids), computed_blocks)
     manager.free(req)
 
     # New request with same tokens + Eagle enabled
     req_eagle = make_request("eagle_divisible", token_ids)
-    computed_blocks, num_tokens = manager.get_computed_blocks(
-        req_eagle, enable_eagle=True)
+    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle,
+                                                              use_eagle=True)
 
     # Should retain 2 blocks:
     # 1. Original 3 blocks → pop last hash → 2 matched blocks
-    # 2. last_block_hash is not None → Eagle pop is SKIPPED
-    assert len(computed_blocks) == 2
-    assert num_tokens == 2 * block_size  # 32 tokens
+    # 2. Eagle drops the last matched block as well → only 1 block retained
+    assert len(computed_blocks) == 1
+    assert num_tokens == 1 * block_size  # 16 tokens
 
 
 def test_eagle_with_partial_blocks():
@@ -792,14 +792,14 @@ def test_eagle_with_partial_blocks():
     req = make_request("partial_block_test", token_ids)
 
     # Prime the cache
-    computed_blocks, _ = manager.get_computed_blocks(req, enable_eagle=False)
+    computed_blocks, _ = manager.get_computed_blocks(req, use_eagle=False)
     manager.allocate_slots(req, len(token_ids), computed_blocks)
     manager.free(req)
 
     # New request with Eagle enabled
     req_eagle = make_request("partial_eagle", token_ids)
-    computed_blocks, num_tokens = manager.get_computed_blocks(
-        req_eagle, enable_eagle=True)
+    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle,
+                                                              use_eagle=True)
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
     assert len(computed_blocks) == 1
     assert num_tokens == 1 * block_size
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 377463e2e9c5..1c55beeea877 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -143,50 +143,3 @@ def test_eagle_correctness(
     # Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs)) del spec_llm - - -def test_prefix_cache( - monkeypatch: pytest.MonkeyPatch, - test_prompts: list[list[dict[str, Any]]], - sampling_config: SamplingParams, - model_name: str, - eagle_model_name: str, -): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - - # Populate the cache with a few prompts - spec_llm = LLM(model=model_name, - speculative_config={ - "method": "eagle", - "model": eagle_model_name, - "num_speculative_tokens": 3, - }, - max_model_len=1024, - disable_log_stats=False) - - first_outputs = spec_llm.chat(test_prompts, sampling_config) - - # Now, run the same prompts again to check if the cache is used - second_outputs = spec_llm.chat(test_prompts, sampling_config) - - # Check: - # 1. The output is the almost same as the first run. - # 2. There is cache hit. - matches = 0 - misses = 0 - for first_output, second_output in zip(first_outputs, second_outputs): - if first_output.outputs[0].text == second_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {first_output.outputs[0].text}") - print(f"spec_output: {second_output.outputs[0].text}") - - # Heuristic: expect at least 70% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.9 * len(first_outputs)) - - print(spec_llm.llm_engine.engine_core.core_engine) - - del spec_llm diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 373891754613..ca369893d3e6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -87,15 +87,14 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks( - self, request: Request, - enable_eagle: bool) -> tuple[list[KVCacheBlock], int]: + def get_computed_blocks(self, request: Request, + use_eagle: bool) -> tuple[list[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. Args: request: The request to get the computed blocks. - enable_eagle: Whether to enable eagle spec decode. If True, + use_eagle: Whether to enable eagle spec decode. If True, we will drop the last matched block so that we can recompute the last block to get the required hidden states for eagle drafting head. @@ -140,11 +139,9 @@ def get_computed_blocks( computed_blocks = ( self.specialized_manager.find_longest_cache_hit(block_hashes)) - if enable_eagle and len( - computed_blocks) > 0 and last_block_hash is None: + if use_eagle and len(computed_blocks) > 0: # Drop the last matched block if (1) eagle is enabled and - # (2) there is a cache hit and (3) the last block hash is - # not removed. + # (2) there is a cache hit. # This is to recompute the last block to get the required # hidden states for eagle drafting head. computed_blocks.pop() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dc248c03bd54..a252d39b53b4 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -126,7 +126,9 @@ def __init__( self.num_lookahead_tokens = 0 speculative_config = vllm_config.speculative_config + self.use_eagle = False if speculative_config and speculative_config.method == "eagle": + self.use_eagle = True self.num_lookahead_tokens = \ speculative_config.num_speculative_tokens @@ -322,7 +324,7 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. 
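
# Illustrative aside (not part of the patch above): in this second patch the
# scheduler stops using `self.num_lookahead_tokens > 0` as a proxy for eagle
# and keeps an explicit `self.use_eagle` flag derived from the speculative
# config, presumably so the last-block drop stays specific to the eagle method
# rather than to any configuration that schedules lookahead tokens. A minimal
# sketch of that derivation; `SpecConfig` is a hypothetical stand-in for
# vllm_config.speculative_config, not a vLLM class.
from dataclasses import dataclass
from typing import Optional

@dataclass
class SpecConfig:
    method: str
    num_speculative_tokens: int

def derive_spec_flags(spec: Optional[SpecConfig]) -> tuple[bool, int]:
    """Mirror the scheduler __init__: return (use_eagle, num_lookahead_tokens)."""
    use_eagle = False
    num_lookahead_tokens = 0
    if spec is not None and spec.method == "eagle":
        use_eagle = True
        num_lookahead_tokens = spec.num_speculative_tokens
    return use_eagle, num_lookahead_tokens

assert derive_spec_flags(SpecConfig("eagle", 3)) == (True, 3)
assert derive_spec_flags(None) == (False, 0)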
computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( - request, self.num_lookahead_tokens > 0) + request, self.use_eagle) # Get externally-cached tokens if using a KVConnector. num_external_tokens = ( From 9c1d37c9c130c58506e26c5a2f6a7f1424ea1e8d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 26 Apr 2025 21:17:23 -0700 Subject: [PATCH 3/4] fix comments/ Signed-off-by: LiuXiaoxuanPKU --- tests/v1/core/test_prefix_caching.py | 92 ++++++++++------------------ vllm/v1/core/kv_cache_manager.py | 8 ++- vllm/v1/core/sched/scheduler.py | 18 +++--- 3 files changed, 48 insertions(+), 70 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 58019797c4fd..094ff008b710 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -74,8 +74,7 @@ def test_prefill(hash_algo): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req0, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 @@ -101,8 +100,7 @@ def test_prefill(hash_algo): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req1, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 @@ -134,8 +132,7 @@ def test_prefill(hash_algo): # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req2, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 @@ -158,8 +155,7 @@ def test_prefill(hash_algo): # Cache miss and eviction. 
req3 = make_request("3", [99] * (16 * 10)) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req3, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) @@ -194,8 +190,7 @@ def test_prefill_plp(): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, prompt_logprobs=5) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req0, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 @@ -223,8 +218,7 @@ def test_prefill_plp(): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req1, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 @@ -258,8 +252,7 @@ def test_prefill_plp(): req2 = make_request("2", common_token_ids + unique_token_ids, prompt_logprobs=5) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req2, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 @@ -291,8 +284,7 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req0, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) @@ -327,8 +319,7 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req0, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) @@ -337,8 +328,7 @@ def test_evict(): # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req1, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) @@ -358,8 +348,7 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req2, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert [b.block_id for b in computed_blocks] == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) @@ -382,8 +371,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. 
num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, computed_blocks) @@ -395,8 +383,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. req = make_request("1", list(range(num_tokens - 1))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) @@ -420,8 +407,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req0, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) @@ -430,8 +416,7 @@ def test_computed_blocks_not_evicted(): # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req1, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) @@ -445,8 +430,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req2, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks) == 1 assert computed_blocks[0].block_id == 1 assert num_computed_tokens == block_size @@ -470,8 +454,7 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req1, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, computed_blocks) @@ -482,8 +465,7 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req2, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, computed_blocks) @@ -491,8 +473,7 @@ def test_basic_prefix_caching_disabled(): # New requests should not have any blocks. 
req3 = make_request("3", list(range(4))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req3, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, computed_blocks) @@ -582,8 +563,7 @@ def test_mm_prefix_caching(): all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req0, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks @@ -619,8 +599,7 @@ def test_mm_prefix_caching(): all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req1, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks) == 3 assert num_computed_tokens == 3 * 16 @@ -642,8 +621,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req0, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, computed_blocks) @@ -651,8 +629,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req1, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, computed_blocks) @@ -666,8 +643,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req2, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, computed_blocks) @@ -677,8 +653,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req3, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. 
@@ -706,7 +681,7 @@ def test_reset_prefix_cache(): unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) - computed_blocks, _ = manager.get_computed_blocks(req1, False) + computed_blocks, _ = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) @@ -737,8 +712,7 @@ def test_prefix_cache_stats_disabled(): # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks( - req, False) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks assert num_computed_tokens == 0 manager.allocate_slots(req, 16, computed_blocks) @@ -756,6 +730,7 @@ def test_eagle_enabled_removes_last_block(): make_kv_cache_config(block_size, num_blocks=10), max_model_len=8192, enable_caching=True, + use_eagle=True, ) # Request with 3 full blocks (48 tokens) @@ -763,14 +738,13 @@ def test_eagle_enabled_removes_last_block(): req = make_request("divisible_request", token_ids) # Prime the cache - computed_blocks, _ = manager.get_computed_blocks(req, use_eagle=False) + computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), computed_blocks) manager.free(req) # New request with same tokens + Eagle enabled req_eagle = make_request("eagle_divisible", token_ids) - computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle, - use_eagle=True) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 2 blocks: # 1. Original 3 blocks → pop last hash → 2 matched blocks @@ -786,20 +760,20 @@ def test_eagle_with_partial_blocks(): make_kv_cache_config(block_size, num_blocks=10), max_model_len=8192, enable_caching=True, + use_eagle=True, ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) req = make_request("partial_block_test", token_ids) # Prime the cache - computed_blocks, _ = manager.get_computed_blocks(req, use_eagle=False) + computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), computed_blocks) manager.free(req) # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids) - computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle, - use_eagle=True) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks) == 1 assert num_tokens == 1 * block_size diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ca369893d3e6..bf83fc3b51cb 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -25,6 +25,7 @@ def __init__( max_model_len: int, enable_caching: bool = True, caching_hash_algo: str = "builtin", + use_eagle: bool = False, log_stats: bool = False, ) -> None: assert len(kv_cache_config.kv_cache_groups) == 1, ( @@ -38,6 +39,7 @@ def __init__( self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash + self.use_eagle = use_eagle self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None @@ -87,8 +89,8 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: 
self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, request: Request, - use_eagle: bool) -> tuple[list[KVCacheBlock], int]: + def get_computed_blocks( + self, request: Request) -> tuple[list[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -139,7 +141,7 @@ def get_computed_blocks(self, request: Request, computed_blocks = ( self.specialized_manager.find_longest_cache_hit(block_hashes)) - if use_eagle and len(computed_blocks) > 0: + if self.use_eagle and len(computed_blocks) > 0: # Drop the last matched block if (1) eagle is enabled and # (2) there is a cache hit. # This is to recompute the last block to get the required diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d31fedc633b3..7c3fd3545d17 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -74,13 +74,6 @@ def __init__( num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 - # Create the KV cache manager. - self.kv_cache_manager = KVCacheManager( - kv_cache_config=kv_cache_config, - max_model_len=self.max_model_len, - enable_caching=self.cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, - log_stats=self.log_stats) self.block_size = self.cache_config.block_size # req_id -> Request @@ -132,6 +125,15 @@ def __init__( self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens + # Create the KV cache manager. + self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, + enable_caching=self.cache_config.enable_prefix_caching, + caching_hash_algo=self.cache_config.prefix_caching_hash_algo, + use_eagle=self.use_eagle, + log_stats=self.log_stats) + def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. @@ -321,7 +323,7 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( - request, self.use_eagle) + request) # Get externally-cached tokens if using a KVConnector. num_external_tokens = ( From 1f2a63038ba7c5199215e9ed9733d08afbf64d7d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 26 Apr 2025 21:44:08 -0700 Subject: [PATCH 4/4] minor Signed-off-by: LiuXiaoxuanPKU --- vllm/v1/core/kv_cache_manager.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bf83fc3b51cb..0830d8433d89 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -96,10 +96,6 @@ def get_computed_blocks( Args: request: The request to get the computed blocks. - use_eagle: Whether to enable eagle spec decode. If True, - we will drop the last matched block so that we can recompute - the last block to get the required hidden states for eagle - drafting head. Returns: A tuple containing:
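
# Illustrative aside (not part of the patches above): after patches 3 and 4
# the eagle flag lives on KVCacheManager itself instead of being threaded
# through every get_computed_blocks() call, and the scheduler builds the
# manager with use_eagle=self.use_eagle. A rough usage sketch mirroring the
# updated unit tests; make_kv_cache_config and make_request are the helpers
# defined in tests/v1/core/test_prefix_caching.py and are assumed here rather
# than redefined.
from vllm.v1.core.kv_cache_manager import KVCacheManager

block_size = 16
manager = KVCacheManager(
    make_kv_cache_config(block_size, num_blocks=10),
    max_model_len=8192,
    enable_caching=True,
    use_eagle=True,  # set once at construction time
)

# Prime the cache with 2 full blocks + 5 extra tokens.
token_ids = [0] * (2 * block_size + 5)
req = make_request("prime", token_ids)
computed_blocks, _ = manager.get_computed_blocks(req)  # no per-call flag
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.free(req)

# Same prompt again: 2 full blocks hit the cache, then the manager drops the
# last matched block so eagle can recompute it for the drafting head.
req2 = make_request("reuse", token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert num_computed_tokens == block_size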