Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
KVCacheGroupSpec, SlidingWindowSpec)


def make_request(request_id,
Expand Down Expand Up @@ -745,11 +745,11 @@ def test_eagle_enabled_removes_last_block():
req_eagle = make_request("eagle_divisible", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)

# Should retain 2 blocks:
# Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. last_block_hash is not NoneEagle pop is not SKIPPED
# 2. drop last matched block1 remaining block
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size # 32 tokens
assert num_tokens == 1 * block_size # 16 tokens


def test_eagle_with_partial_blocks():
Expand All @@ -776,3 +776,59 @@ def test_eagle_with_partial_blocks():
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size


def test_eagle_with_sliding_window():
"""Test Eagle behavior with sliding window."""
block_size = 16
sliding_window_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
use_mla=False,
)
manager = KVCacheManager(
KVCacheConfig(
num_blocks=10,
tensors={},
kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
),
max_model_len=8192,
enable_caching=True,
use_eagle=True,
)

# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids)

# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None
manager.free(req)

# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size

# Evict the first block in the request
assert manager.block_pool.get_cached_block(
block_hash_first_block) is not None
manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)

# New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# there will be no matched prefix.
assert len(computed_blocks) == 0
assert num_tokens == 0
8 changes: 6 additions & 2 deletions tests/v1/core/test_specialized_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def test_sliding_window_possible_cached_prefix():
)

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool)
manager = SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False)

def run_one_case(block_is_cached, expect_length):
block_hash_list = [
Expand Down Expand Up @@ -79,7 +81,9 @@ def test_sliding_window_remove_skipped_blocks():

block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)

manager = SlidingWindowManager(sliding_window_spec, block_pool)
manager = SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False)

null_block_id = block_pool.null_block.block_id

Expand Down
8 changes: 1 addition & 7 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
self.specialized_manager = get_specialized_manager(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
use_eagle=self.use_eagle,
)

# Mapping from request ID to blocks to track the blocks allocated
Expand Down Expand Up @@ -137,13 +138,6 @@ def get_computed_blocks(
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))

if self.use_eagle and len(computed_blocks) > 0:
# Drop the last matched block if (1) eagle is enabled and
# (2) there is a cache hit.
# This is to recompute the last block to get the required
# hidden states for eagle drafting head.
computed_blocks.pop()

if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.queries += len(block_hashes)
Expand Down
39 changes: 29 additions & 10 deletions vllm/v1/core/specialized_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
use_eagle: bool,
) -> None:
"""
Initializes the SpecializedManager.
Expand All @@ -30,12 +31,17 @@ def __init__(
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool

# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle

@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list.
found, return an empty list. if eagle is enabled, drop the last matched
block to force recompute the last block to get the required hidden
states for eagle drafting head.

Args:
block_hashes: The block hashes of the request.
Expand Down Expand Up @@ -79,6 +85,8 @@ def find_longest_cache_hit(
computed_blocks.append(cached_block)
else:
break
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
return computed_blocks

def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
Expand All @@ -89,14 +97,20 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],

class SlidingWindowManager(SpecializedManager):

def __init__(self, kv_cache_spec: SlidingWindowSpec,
block_pool: BlockPool):
super().__init__(kv_cache_spec, block_pool)
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
use_eagle: bool):
super().__init__(kv_cache_spec, block_pool, use_eagle)
self.sliding_window = kv_cache_spec.sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self.sliding_window_contiguous_blocks = cdiv(
(kv_cache_spec.sliding_window - 1), self.block_size)
if self.use_eagle:
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
self.sliding_window_contiguous_blocks += 1
self._null_block = block_pool.null_block

def find_longest_cache_hit(
Expand All @@ -109,6 +123,7 @@ def find_longest_cache_hit(
computed_blocks = [self._null_block] * len(block_hashes)
num_contiguous_blocks = 0

match_found = False
# Search from right to left and early stop when a match is found.
for i in range(len(block_hashes) - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
Expand All @@ -121,12 +136,16 @@ def find_longest_cache_hit(
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del computed_blocks[i + num_contiguous_blocks:]
return computed_blocks
match_found = True
break
else:
num_contiguous_blocks = 0
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
if not match_found:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
return computed_blocks

def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
Expand Down Expand Up @@ -155,7 +174,7 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],


def get_specialized_manager(kv_cache_spec: KVCacheSpec,
block_pool: BlockPool) -> SpecializedManager:
**kwargs) -> SpecializedManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, block_pool)
manager = manager_class(kv_cache_spec, **kwargs)
return manager