Skip to content

Commit 81ecf42

Browse files
authored
[v1][Spec Decode] Make sliding window compatible with eagle prefix caching (#17398)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent 42d9a2c commit 81ecf42

File tree

4 files changed

+96
-23
lines changed

4 files changed

+96
-23
lines changed

tests/v1/core/test_prefix_caching.py

Lines changed: 60 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
1616
hash_block_tokens)
1717
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
18-
KVCacheGroupSpec)
18+
KVCacheGroupSpec, SlidingWindowSpec)
1919

2020

2121
def make_request(request_id,
@@ -863,11 +863,11 @@ def test_eagle_enabled_removes_last_block():
863863
req_eagle = make_request("eagle_divisible", token_ids)
864864
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
865865

866-
# Should retain 2 blocks:
866+
# Should retain 1 block:
867867
# 1. Original 3 blocks → pop last hash → 2 matched blocks
868-
# 2. last_block_hash is not None → Eagle pop is not SKIPPED
868+
# 2. drop last matched block → 1 remaining block
869869
assert len(computed_blocks) == 1
870-
assert num_tokens == 1 * block_size # 32 tokens
870+
assert num_tokens == 1 * block_size # 16 tokens
871871

872872

873873
def test_eagle_with_partial_blocks():
@@ -894,3 +894,59 @@ def test_eagle_with_partial_blocks():
894894
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
895895
assert len(computed_blocks) == 1
896896
assert num_tokens == 1 * block_size
897+
898+
899+
def test_eagle_with_sliding_window():
900+
"""Test Eagle behavior with sliding window."""
901+
block_size = 16
902+
sliding_window_spec = SlidingWindowSpec(
903+
block_size=block_size,
904+
num_kv_heads=1,
905+
head_size=1,
906+
dtype=torch.float32,
907+
sliding_window=block_size,
908+
use_mla=False,
909+
)
910+
manager = KVCacheManager(
911+
KVCacheConfig(
912+
num_blocks=10,
913+
tensors={},
914+
kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
915+
),
916+
max_model_len=8192,
917+
enable_caching=True,
918+
use_eagle=True,
919+
)
920+
921+
# 2 full blocks + 5 tokens (non-divisible length)
922+
token_ids = [0] * (2 * block_size + 5)
923+
req = make_request("partial_block_test", token_ids)
924+
925+
# Prime the cache
926+
computed_blocks, _ = manager.get_computed_blocks(req)
927+
manager.allocate_slots(req, len(token_ids), computed_blocks)
928+
# record the block hash of the first block in the request for later use
929+
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
930+
assert block_hash_first_block is not None
931+
manager.free(req)
932+
933+
# New request with Eagle enabled
934+
req_eagle = make_request("partial_eagle", token_ids)
935+
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
936+
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
937+
assert len(computed_blocks) == 1
938+
assert num_tokens == 1 * block_size
939+
940+
# Evict the first block in the request
941+
assert manager.block_pool.get_cached_block(
942+
block_hash_first_block) is not None
943+
manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
944+
945+
# New request
946+
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
947+
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
948+
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
949+
# not considered. But after dropping the last matched block due to eagle,
950+
# there will be no matched prefix.
951+
assert len(computed_blocks) == 0
952+
assert num_tokens == 0

tests/v1/core/test_specialized_manager.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,9 @@ def test_sliding_window_possible_cached_prefix():
1919
)
2020

2121
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
22-
manager = SlidingWindowManager(sliding_window_spec, block_pool)
22+
manager = SlidingWindowManager(sliding_window_spec,
23+
block_pool,
24+
use_eagle=False)
2325

2426
def run_one_case(block_is_cached, expect_length):
2527
block_hash_list = [
@@ -79,7 +81,9 @@ def test_sliding_window_remove_skipped_blocks():
7981

8082
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
8183

82-
manager = SlidingWindowManager(sliding_window_spec, block_pool)
84+
manager = SlidingWindowManager(sliding_window_spec,
85+
block_pool,
86+
use_eagle=False)
8387

8488
null_block_id = block_pool.null_block.block_id
8589

vllm/v1/core/kv_cache_manager.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ def __init__(
5252
self.specialized_manager = get_specialized_manager(
5353
kv_cache_spec=kv_cache_spec,
5454
block_pool=self.block_pool,
55+
use_eagle=self.use_eagle,
5556
)
5657

5758
# Mapping from request ID to blocks to track the blocks allocated
@@ -141,13 +142,6 @@ def get_computed_blocks(
141142
computed_blocks = (
142143
self.specialized_manager.find_longest_cache_hit(block_hashes))
143144

144-
if self.use_eagle and len(computed_blocks) > 0:
145-
# Drop the last matched block if (1) eagle is enabled and
146-
# (2) there is a cache hit.
147-
# This is to recompute the last block to get the required
148-
# hidden states for eagle drafting head.
149-
computed_blocks.pop()
150-
151145
if self.log_stats:
152146
assert self.prefix_cache_stats is not None
153147
self.prefix_cache_stats.queries += len(block_hashes)

vllm/v1/core/specialized_manager.py

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ def __init__(
1818
self,
1919
kv_cache_spec: KVCacheSpec,
2020
block_pool: BlockPool,
21+
use_eagle: bool,
2122
) -> None:
2223
"""
2324
Initializes the SpecializedManager.
@@ -30,12 +31,17 @@ def __init__(
3031
self.kv_cache_spec = kv_cache_spec
3132
self.block_pool = block_pool
3233

34+
# Needs special handling for find_longest_cache_hit if eagle is enabled
35+
self.use_eagle = use_eagle
36+
3337
@abstractmethod
3438
def find_longest_cache_hit(
3539
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
3640
"""
3741
Get the longest cache hit prefix of the blocks. If no cache hit is
38-
found, return an empty list.
42+
found, return an empty list. If eagle is enabled, drop the last matched
43+
block to force recompute the last block to get the required hidden
44+
states for eagle drafting head.
3945
4046
Args:
4147
block_hashes: The block hashes of the request.
@@ -79,6 +85,8 @@ def find_longest_cache_hit(
7985
computed_blocks.append(cached_block)
8086
else:
8187
break
88+
if self.use_eagle and len(computed_blocks) > 0:
89+
computed_blocks.pop()
8290
return computed_blocks
8391

8492
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
@@ -89,14 +97,20 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
8997

9098
class SlidingWindowManager(SpecializedManager):
9199

92-
def __init__(self, kv_cache_spec: SlidingWindowSpec,
93-
block_pool: BlockPool):
94-
super().__init__(kv_cache_spec, block_pool)
100+
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
101+
use_eagle: bool):
102+
super().__init__(kv_cache_spec, block_pool, use_eagle)
95103
self.sliding_window = kv_cache_spec.sliding_window
96104
# The number of contiguous blocks needed for prefix cache hit.
97105
# -1 since the input token itself is also included in the window
98106
self.sliding_window_contiguous_blocks = cdiv(
99107
(kv_cache_spec.sliding_window - 1), self.block_size)
108+
if self.use_eagle:
109+
# Need to drop the last matched block if eagle is enabled. For
110+
# sliding window layer, we achieve this by increasing the number of
111+
# contiguous blocks needed for prefix cache hit by one and dropping
112+
# the last matched block.
113+
self.sliding_window_contiguous_blocks += 1
100114
self._null_block = block_pool.null_block
101115

102116
def find_longest_cache_hit(
@@ -109,6 +123,7 @@ def find_longest_cache_hit(
109123
computed_blocks = [self._null_block] * len(block_hashes)
110124
num_contiguous_blocks = 0
111125

126+
match_found = False
112127
# Search from right to left and early stop when a match is found.
113128
for i in range(len(block_hashes) - 1, -1, -1):
114129
if cached_block := self.block_pool.get_cached_block(
@@ -121,12 +136,16 @@ def find_longest_cache_hit(
121136
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
122137
# when sliding_window_contiguous_blocks=2.
123138
del computed_blocks[i + num_contiguous_blocks:]
124-
return computed_blocks
139+
match_found = True
140+
break
125141
else:
126142
num_contiguous_blocks = 0
127-
# The first `num_contiguous_blocks` is a cache hit even if
128-
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
129-
del computed_blocks[num_contiguous_blocks:]
143+
if not match_found:
144+
# The first `num_contiguous_blocks` is a cache hit even if
145+
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
146+
del computed_blocks[num_contiguous_blocks:]
147+
if self.use_eagle and len(computed_blocks) > 0:
148+
computed_blocks.pop()
130149
return computed_blocks
131150

132151
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
@@ -155,7 +174,7 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
155174

156175

157176
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
158-
block_pool: BlockPool) -> SpecializedManager:
177+
**kwargs) -> SpecializedManager:
159178
manager_class = spec_manager_map[type(kv_cache_spec)]
160-
manager = manager_class(kv_cache_spec, block_pool)
179+
manager = manager_class(kv_cache_spec, **kwargs)
161180
return manager

0 commit comments

Comments
 (0)