From d56c0ad8842e10fb384205cecbaa86198f502c3c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Apr 2025 09:16:06 -0700 Subject: [PATCH 1/4] save Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 63 ++++++++++++++++++++++++++-- vllm/v1/core/kv_cache_manager.py | 7 ---- vllm/v1/core/specialized_manager.py | 41 +++++++++++++----- 3 files changed, 90 insertions(+), 21 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index b2e8ff61450c..f204c903fbc5 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -14,7 +14,7 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_block_tokens) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) + KVCacheGroupSpec, SlidingWindowSpec) def make_request(request_id, @@ -745,11 +745,11 @@ def test_eagle_enabled_removes_last_block(): req_eagle = make_request("eagle_divisible", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) - # Should retain 2 blocks: + # Should retain 1 block: # 1. Original 3 blocks → pop last hash → 2 matched blocks - # 2. last_block_hash is not None → Eagle pop is not SKIPPED + # 2. drop last matched block → 1 remaining block assert len(computed_blocks) == 1 - assert num_tokens == 1 * block_size # 32 tokens + assert num_tokens == 1 * block_size # 16 tokens def test_eagle_with_partial_blocks(): @@ -776,3 +776,58 @@ def test_eagle_with_partial_blocks(): # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks) == 1 assert num_tokens == 1 * block_size + + +def test_eagle_with_sliding_window(): + """Test Eagle behavior with sliding window.""" + block_size = 16 + sliding_window_spec = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + use_mla=False, + ) + manager = KVCacheManager( + KVCacheConfig( + num_blocks=10, + tensors={}, + kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)], + ), + max_model_len=8192, + enable_caching=True, + use_eagle=True, + ) + + # 2 full blocks + 5 tokens (non-divisible length) + token_ids = [0] * (2 * block_size + 5) + req = make_request("partial_block_test", token_ids) + + # Prime the cache + computed_blocks, _ = manager.get_computed_blocks(req) + manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.free(req) + + # New request with Eagle enabled + req_eagle = make_request("partial_eagle", token_ids) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) + # Original match: 2 full blocks → Eagle removes 1 → 1 remaining + assert len(computed_blocks) == 1 + assert num_tokens == 1 * block_size + + # Evict the first block in the request + block_hash_first_block = computed_blocks[0].block_hash + assert block_hash_first_block is not None + assert manager.block_pool.get_cached_block( + block_hash_first_block) is not None + manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block) + + # New request + req_after_evict = make_request("partial_eagle_after_evict", token_ids) + computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) + # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is + # not considered. But after dropping the last matched block due to eagle, + # there will be no matched prefix. + assert len(computed_blocks) == 0 + assert num_tokens == 0 diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0830d8433d89..8e2328652747 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -137,13 +137,6 @@ def get_computed_blocks( computed_blocks = ( self.specialized_manager.find_longest_cache_hit(block_hashes)) - if self.use_eagle and len(computed_blocks) > 0: - # Drop the last matched block if (1) eagle is enabled and - # (2) there is a cache hit. - # This is to recompute the last block to get the required - # hidden states for eagle drafting head. - computed_blocks.pop() - if self.log_stats: assert self.prefix_cache_stats is not None self.prefix_cache_stats.queries += len(block_hashes) diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index 7a8a98361c7e..ef136f678bf2 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -18,6 +18,7 @@ def __init__( self, kv_cache_spec: KVCacheSpec, block_pool: BlockPool, + use_eagle: bool, ) -> None: """ Initializes the SpecializedManager. @@ -30,12 +31,17 @@ def __init__( self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool + # Needs special handling for find_longest_cache_hit if eagle is enabled + self.use_eagle = use_eagle + @abstractmethod def find_longest_cache_hit( self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: """ Get the longest cache hit prefix of the blocks. If no cache hit is - found, return an empty list. + found, return an empty list. if eagle is enabled, drop the last matched + block to force recompute the last block to get the required hidden + states for eagle drafting head. Args: block_hashes: The block hashes of the request. @@ -79,6 +85,8 @@ def find_longest_cache_hit( computed_blocks.append(cached_block) else: break + if self.use_eagle and len(computed_blocks) > 0: + computed_blocks.pop() return computed_blocks def remove_skipped_blocks(self, blocks: list[KVCacheBlock], @@ -89,14 +97,22 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], class SlidingWindowManager(SpecializedManager): - def __init__(self, kv_cache_spec: SlidingWindowSpec, - block_pool: BlockPool): - super().__init__(kv_cache_spec, block_pool) + def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, + use_eagle: bool): + super().__init__(kv_cache_spec, block_pool, use_eagle) self.sliding_window = kv_cache_spec.sliding_window # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window self.sliding_window_contiguous_blocks = cdiv( (kv_cache_spec.sliding_window - 1), self.block_size) + print("conti b eagle", self.sliding_window_contiguous_blocks) + if self.use_eagle: + # Need to drop the last matched block if eagle is enabled. For + # sliding window layer, we achieve this by increasing the number of + # contiguous blocks needed for prefix cache hit by one and dropping + # the last matched block. + self.sliding_window_contiguous_blocks += 1 + print("conti b eagle", self.sliding_window_contiguous_blocks) self._null_block = block_pool.null_block def find_longest_cache_hit( @@ -109,6 +125,7 @@ def find_longest_cache_hit( computed_blocks = [self._null_block] * len(block_hashes) num_contiguous_blocks = 0 + match_found = False # Search from right to left and early stop when a match is found. for i in range(len(block_hashes) - 1, -1, -1): if cached_block := self.block_pool.get_cached_block( @@ -121,12 +138,16 @@ def find_longest_cache_hit( # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. del computed_blocks[i + num_contiguous_blocks:] - return computed_blocks + match_found = True + break else: num_contiguous_blocks = 0 - # The first `num_contiguous_blocks` is a cache hit even if - # `num_contiguous_blocks < sliding_window_contiguous_blocks`. - del computed_blocks[num_contiguous_blocks:] + if not match_found: + # The first `num_contiguous_blocks` is a cache hit even if + # `num_contiguous_blocks < sliding_window_contiguous_blocks`. + del computed_blocks[num_contiguous_blocks:] + if self.use_eagle and len(computed_blocks) > 0: + computed_blocks.pop() return computed_blocks def remove_skipped_blocks(self, blocks: list[KVCacheBlock], @@ -155,7 +176,7 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], def get_specialized_manager(kv_cache_spec: KVCacheSpec, - block_pool: BlockPool) -> SpecializedManager: + **kwargs) -> SpecializedManager: manager_class = spec_manager_map[type(kv_cache_spec)] - manager = manager_class(kv_cache_spec, block_pool) + manager = manager_class(kv_cache_spec, **kwargs) return manager From 0f89a0ed0bc86eb33ec4f098951a8dbfd38d4e39 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Apr 2025 09:32:52 -0700 Subject: [PATCH 2/4] fix test Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8e2328652747..2d959aaff45e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -48,6 +48,7 @@ def __init__( self.specialized_manager = get_specialized_manager( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, + use_eagle=self.use_eagle, ) # Mapping from request ID to blocks to track the blocks allocated From b9ac04670576f997c7810a093f8da06ccfcf823d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Apr 2025 09:38:58 -0700 Subject: [PATCH 3/4] fix test Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 5 +++-- vllm/v1/core/specialized_manager.py | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index f204c903fbc5..a2732dac83e9 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -807,6 +807,9 @@ def test_eagle_with_sliding_window(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), computed_blocks) + # record the block hash of the first block in the request for later use + block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] + assert block_hash_first_block is not None manager.free(req) # New request with Eagle enabled @@ -817,8 +820,6 @@ def test_eagle_with_sliding_window(): assert num_tokens == 1 * block_size # Evict the first block in the request - block_hash_first_block = computed_blocks[0].block_hash - assert block_hash_first_block is not None assert manager.block_pool.get_cached_block( block_hash_first_block) is not None manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block) diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index ef136f678bf2..f04eedf42662 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -105,14 +105,12 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, # -1 since the input token itself is also included in the window self.sliding_window_contiguous_blocks = cdiv( (kv_cache_spec.sliding_window - 1), self.block_size) - print("conti b eagle", self.sliding_window_contiguous_blocks) if self.use_eagle: # Need to drop the last matched block if eagle is enabled. For # sliding window layer, we achieve this by increasing the number of # contiguous blocks needed for prefix cache hit by one and dropping # the last matched block. self.sliding_window_contiguous_blocks += 1 - print("conti b eagle", self.sliding_window_contiguous_blocks) self._null_block = block_pool.null_block def find_longest_cache_hit( From 1ed6f4e528ccc264b262b97271735f77dbe1a430 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 30 Apr 2025 03:24:08 -0700 Subject: [PATCH 4/4] fix test Signed-off-by: Chen Zhang --- tests/v1/core/test_specialized_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 9b4ab5fa8b12..595c8608fc64 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -19,7 +19,9 @@ def test_sliding_window_possible_cached_prefix(): ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) - manager = SlidingWindowManager(sliding_window_spec, block_pool) + manager = SlidingWindowManager(sliding_window_spec, + block_pool, + use_eagle=False) def run_one_case(block_is_cached, expect_length): block_hash_list = [ @@ -79,7 +81,9 @@ def test_sliding_window_remove_skipped_blocks(): block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) - manager = SlidingWindowManager(sliding_window_spec, block_pool) + manager = SlidingWindowManager(sliding_window_spec, + block_pool, + use_eagle=False) null_block_id = block_pool.null_block.block_id