diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index b2e8ff61450c..a2732dac83e9 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -14,7 +14,7 @@
 from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
                                          hash_block_tokens)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec)
+                                        KVCacheGroupSpec, SlidingWindowSpec)
 
 
 def make_request(request_id,
@@ -745,11 +745,11 @@ def test_eagle_enabled_removes_last_block():
     req_eagle = make_request("eagle_divisible", token_ids)
     computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
 
-    # Should retain 2 blocks:
+    # Should retain 1 block:
     # 1. Original 3 blocks → pop last hash → 2 matched blocks
-    # 2. last_block_hash is not None → Eagle pop is not SKIPPED
+    # 2. Drop last matched block → 1 remaining block
     assert len(computed_blocks) == 1
-    assert num_tokens == 1 * block_size  # 32 tokens
+    assert num_tokens == 1 * block_size  # 16 tokens
 
 
 def test_eagle_with_partial_blocks():
@@ -776,3 +776,59 @@ def test_eagle_with_partial_blocks():
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
     assert len(computed_blocks) == 1
     assert num_tokens == 1 * block_size
+
+
+def test_eagle_with_sliding_window():
+    """Test Eagle behavior with sliding window."""
+    block_size = 16
+    sliding_window_spec = SlidingWindowSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        sliding_window=block_size,
+        use_mla=False,
+    )
+    manager = KVCacheManager(
+        KVCacheConfig(
+            num_blocks=10,
+            tensors={},
+            kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
+        ),
+        max_model_len=8192,
+        enable_caching=True,
+        use_eagle=True,
+    )
+
+    # 2 full blocks + 5 tokens (non-divisible length)
+    token_ids = [0] * (2 * block_size + 5)
+    req = make_request("partial_block_test", token_ids)
+
+    # Prime the cache
+    computed_blocks, _ = manager.get_computed_blocks(req)
+    manager.allocate_slots(req, len(token_ids), computed_blocks)
+    # Record the block hash of the first block in the request for later use
+    block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
+    assert block_hash_first_block is not None
+    manager.free(req)
+
+    # New request with Eagle enabled
+    req_eagle = make_request("partial_eagle", token_ids)
+    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
+    # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
+    assert len(computed_blocks) == 1
+    assert num_tokens == 1 * block_size
+
+    # Evict the first block in the request
+    assert manager.block_pool.get_cached_block(
+        block_hash_first_block) is not None
+    manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
+
+    # New request
+    req_after_evict = make_request("partial_eagle_after_evict", token_ids)
+    computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
+    # Cache miss. Without Eagle, the longest hit prefix would be
+    # [NULL_BLOCK, BLOCK_2]. After dropping the last matched block for Eagle,
+    # no matched prefix is left.
+    assert len(computed_blocks) == 0
+    assert num_tokens == 0
diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py
index 9b4ab5fa8b12..595c8608fc64 100644
--- a/tests/v1/core/test_specialized_manager.py
+++ b/tests/v1/core/test_specialized_manager.py
@@ -19,7 +19,9 @@ def test_sliding_window_possible_cached_prefix():
     )
 
     block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
-    manager = SlidingWindowManager(sliding_window_spec, block_pool)
+    manager = SlidingWindowManager(sliding_window_spec,
+                                   block_pool,
+                                   use_eagle=False)
 
     def run_one_case(block_is_cached, expect_length):
         block_hash_list = [
@@ -79,7 +81,9 @@ def test_sliding_window_remove_skipped_blocks():
 
     block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
 
-    manager = SlidingWindowManager(sliding_window_spec, block_pool)
+    manager = SlidingWindowManager(sliding_window_spec,
+                                   block_pool,
+                                   use_eagle=False)
 
     null_block_id = block_pool.null_block.block_id
 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 0830d8433d89..2d959aaff45e 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -48,6 +48,7 @@ def __init__(
         self.specialized_manager = get_specialized_manager(
             kv_cache_spec=kv_cache_spec,
             block_pool=self.block_pool,
+            use_eagle=self.use_eagle,
         )
 
         # Mapping from request ID to blocks to track the blocks allocated
@@ -137,13 +138,6 @@ def get_computed_blocks(
         computed_blocks = (
             self.specialized_manager.find_longest_cache_hit(block_hashes))
 
-        if self.use_eagle and len(computed_blocks) > 0:
-            # Drop the last matched block if (1) eagle is enabled and
-            # (2) there is a cache hit.
-            # This is to recompute the last block to get the required
-            # hidden states for eagle drafting head.
-            computed_blocks.pop()
-
         if self.log_stats:
             assert self.prefix_cache_stats is not None
             self.prefix_cache_stats.queries += len(block_hashes)
diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py
index 7a8a98361c7e..f04eedf42662 100644
--- a/vllm/v1/core/specialized_manager.py
+++ b/vllm/v1/core/specialized_manager.py
@@ -18,6 +18,7 @@ def __init__(
         self,
         kv_cache_spec: KVCacheSpec,
         block_pool: BlockPool,
+        use_eagle: bool,
     ) -> None:
         """
         Initializes the SpecializedManager.
@@ -30,12 +31,17 @@ def __init__(
         self.kv_cache_spec = kv_cache_spec
         self.block_pool = block_pool
 
+        # Needs special handling in find_longest_cache_hit if eagle is enabled
+        self.use_eagle = use_eagle
+
     @abstractmethod
     def find_longest_cache_hit(
             self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
         """
         Get the longest cache hit prefix of the blocks. If no cache hit is
-        found, return an empty list.
+        found, return an empty list. If eagle is enabled, drop the last
+        matched block to force recomputation of the last block, which yields
+        the hidden states required by the eagle drafting head.
 
         Args:
             block_hashes: The block hashes of the request.
@@ -79,6 +85,8 @@ def find_longest_cache_hit(
                 computed_blocks.append(cached_block)
             else:
                 break
+        if self.use_eagle and len(computed_blocks) > 0:
+            computed_blocks.pop()
         return computed_blocks
 
     def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
@@ -89,14 +97,20 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
 
 class SlidingWindowManager(SpecializedManager):
 
-    def __init__(self, kv_cache_spec: SlidingWindowSpec,
-                 block_pool: BlockPool):
-        super().__init__(kv_cache_spec, block_pool)
+    def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
+                 use_eagle: bool):
+        super().__init__(kv_cache_spec, block_pool, use_eagle)
         self.sliding_window = kv_cache_spec.sliding_window
         # The number of contiguous blocks needed for prefix cache hit.
         # -1 since the input token itself is also included in the window
         self.sliding_window_contiguous_blocks = cdiv(
             (kv_cache_spec.sliding_window - 1), self.block_size)
+        if self.use_eagle:
+            # Need to drop the last matched block if eagle is enabled. For the
+            # sliding window layer, we achieve this by requiring one more
+            # contiguous block for a prefix cache hit and then dropping the
+            # last matched block.
+            self.sliding_window_contiguous_blocks += 1
         self._null_block = block_pool.null_block
 
     def find_longest_cache_hit(
@@ -109,6 +123,7 @@ def find_longest_cache_hit(
 
         computed_blocks = [self._null_block] * len(block_hashes)
         num_contiguous_blocks = 0
+        match_found = False
         # Search from right to left and early stop when a match is found.
         for i in range(len(block_hashes) - 1, -1, -1):
             if cached_block := self.block_pool.get_cached_block(
@@ -121,12 +136,16 @@ def find_longest_cache_hit(
                     # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                     # when sliding_window_contiguous_blocks=2.
                     del computed_blocks[i + num_contiguous_blocks:]
-                    return computed_blocks
+                    match_found = True
+                    break
             else:
                 num_contiguous_blocks = 0
-        # The first `num_contiguous_blocks` is a cache hit even if
-        # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
-        del computed_blocks[num_contiguous_blocks:]
+        if not match_found:
+            # The first `num_contiguous_blocks` is a cache hit even if
+            # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
+            del computed_blocks[num_contiguous_blocks:]
+        if self.use_eagle and len(computed_blocks) > 0:
+            computed_blocks.pop()
         return computed_blocks
 
     def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
@@ -155,7 +174,7 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
 
 
 def get_specialized_manager(kv_cache_spec: KVCacheSpec,
-                            block_pool: BlockPool) -> SpecializedManager:
+                            **kwargs) -> SpecializedManager:
     manager_class = spec_manager_map[type(kv_cache_spec)]
-    manager = manager_class(kv_cache_spec, block_pool)
+    manager = manager_class(kv_cache_spec, **kwargs)
    return manager
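
Reviewer note: the behavioral core of this change is that the "drop the last matched block when eagle is enabled" rule moves out of `KVCacheManager.get_computed_blocks` and into the specialized managers, which lets the sliding-window manager demand one extra contiguous cached block before dropping the last match. The sketch below is a minimal, self-contained illustration of that rule under the same assumptions as the new test (string hashes, a plain dict standing in for the block pool, `None` standing in for the null block); the helper name and its signature are illustrative only, not vLLM's API.

```python
# Standalone sketch (not vLLM code) of the sliding-window cache-hit rule with
# the eagle adjustment: require one extra contiguous block, then drop the last
# matched block so it is recomputed for the eagle drafting head.
from typing import Optional


def sliding_window_cache_hit(block_hashes: list[str],
                             cached: dict[str, int],
                             required_contiguous: int,
                             use_eagle: bool) -> list[Optional[int]]:
    if use_eagle:
        required_contiguous += 1
    computed: list[Optional[int]] = [None] * len(block_hashes)
    num_contiguous = 0
    match_found = False
    # Search from right to left and stop at the first long-enough run.
    for i in range(len(block_hashes) - 1, -1, -1):
        block_id = cached.get(block_hashes[i])
        if block_id is not None:
            computed[i] = block_id
            num_contiguous += 1
            if num_contiguous >= required_contiguous:
                del computed[i + num_contiguous:]
                match_found = True
                break
        else:
            num_contiguous = 0
    if not match_found:
        # Keep only the contiguous run at the start of the request, if any.
        del computed[num_contiguous:]
    if use_eagle and computed:
        computed.pop()  # force recomputation of the last matched block
    return computed


# Mirrors the new test: 2 full blocks, sliding_window == block_size, so the
# base requirement is 1 contiguous block.
cached = {"h0": 0, "h1": 1}
assert sliding_window_cache_hit(["h0", "h1"], cached, 1, False) == [None, 1]
assert sliding_window_cache_hit(["h0", "h1"], cached, 1, True) == [0]
# With the first block evicted, eagle leaves no usable prefix at all.
assert sliding_window_cache_hit(["h0", "h1"], {"h1": 1}, 1, True) == []
```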