188 changes: 112 additions & 76 deletions vllm/v1/core/single_type_kv_cache_manager.py
@@ -243,18 +243,53 @@ def find_longest_cache_hit(

raise NotImplementedError

@abstractmethod
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks` and free the
blocks. The removed blocks should be replaced by null_block.
Need to be customized for each attention type.
Remove and free the blocks that are no longer needed for attention computation.
The removed blocks should be replaced by null_block.

This function depends on `get_num_skipped_tokens`, which needs to be implemented
differently for each attention type.

Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
raise NotImplementedError
# Remove the blocks that will be skipped during attention computation.
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
if num_skipped_tokens <= 0:
# All tokens are still inside the attention window, so there are no
# blocks outside the window to free. A typical case is full attention,
# where no tokens are freed until the request finishes.
return
num_skipped_blocks = num_skipped_tokens // self.block_size
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
# Because blocks are 0-indexed, the num_skipped_blocks-th block
# corresponds to index num_skipped_blocks - 1.
for i in range(num_skipped_blocks - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)

def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
Get the number of tokens that will be skipped for attention computation.

Args:
num_computed_tokens: The number of tokens that have been computed.

Returns:
The number of tokens that will be skipped for attention computation.
"""
# The default behavior is to not skip any tokens.
return 0
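
To see the refactored control flow in isolation, here is a minimal, runnable
sketch of the template-method pattern introduced above: the base class owns
the backward-walking free loop, and subclasses override only
get_num_skipped_tokens. All names below (ToyManager, ToySlidingWindow,
NULL_BLOCK, BLOCK_SIZE) are illustrative stand-ins, not vLLM's actual API.

NULL_BLOCK = object()  # stand-in for the shared null_block sentinel
BLOCK_SIZE = 4

class ToyManager:
    def __init__(self, blocks: list) -> None:
        self.blocks = blocks

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        return 0  # default: skip nothing (e.g. full attention)

    def remove_skipped_blocks(self, num_computed_tokens: int) -> list:
        num_skipped = self.get_num_skipped_tokens(num_computed_tokens)
        if num_skipped <= 0:
            return []
        freed = []
        # Walk backwards from the last fully skipped block; stop at the
        # first null block, since everything before it was freed by
        # earlier calls.
        for i in range(num_skipped // BLOCK_SIZE - 1, -1, -1):
            if self.blocks[i] is NULL_BLOCK:
                break
            freed.append(self.blocks[i])
            self.blocks[i] = NULL_BLOCK
        return freed

class ToySlidingWindow(ToyManager):
    def __init__(self, blocks: list, window: int) -> None:
        super().__init__(blocks)
        self.window = window

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        return num_computed_tokens - self.window + 1

mgr = ToySlidingWindow(blocks=list("ABCD"), window=6)
print(mgr.remove_skipped_blocks(num_computed_tokens=13))  # ['B', 'A']

Freeing in reverse order is what makes the early break correct: once a null
block is seen, every block before it was already nulled by a previous call.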


class FullAttentionManager(SingleTypeKVCacheManager):
@@ -298,10 +333,6 @@ def find_longest_cache_hit(
computed.pop()
return computed_blocks

def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# No need to remove blocks for full attention.
pass

def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
blocks = self.req_to_blocks[running_request_id]
num_common_blocks = 0
@@ -389,28 +420,33 @@ def find_longest_cache_hit(
computed.pop()
return computed_blocks

def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# Remove the blocks that are no longer be in the sliding window and
# skipped during the attention computation.
last_useful_token = num_computed_tokens - self.sliding_window + 1
last_useful_block = last_useful_token // self.block_size
if last_useful_block <= 0:
# Early return if tokens are not enough to fill the sliding window
return
blocks = self.req_to_blocks[request_id]
if blocks[last_useful_block - 1] == self._null_block:
# Early return if there are no blocks to remove
return
removed_blocks: list[KVCacheBlock] = []
for i in range(last_useful_block - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
Get the number of tokens that will be skipped for attention computation.

For sliding window attention, these are the tokens that precede the
current sliding window.

Example:
sliding_window=4, num_computed_tokens=7

Tokens: [ 0 1 2 3 4 5 6 7 ]
          |- computed -|
                        ^ next token to be computed
                  |------| sliding window for next token
          |------| skipped

The current window contains tokens 4~7. Tokens 0~3 will be skipped for
attention computation since they are outside the sliding window.
Thus, get_num_skipped_tokens(7) == 4.

Args:
num_computed_tokens: The number of tokens that have been computed.

Returns:
The number of tokens that will be skipped for attention computation.
"""
return num_computed_tokens - self.sliding_window + 1
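
As a quick sanity check of this formula, here is a standalone helper (an
illustrative sketch mirroring the method above, not vLLM's API) evaluated
against the docstring example:

def sliding_window_skipped(num_computed_tokens: int, sliding_window: int) -> int:
    # Tokens before the next token's window can be skipped.
    return num_computed_tokens - sliding_window + 1

assert sliding_window_skipped(7, 4) == 4  # docstring example: tokens 0~3 skipped
assert sliding_window_skipped(3, 4) == 0  # window not yet full: nothing skipped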

def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
@@ -511,40 +547,51 @@ def find_longest_cache_hit(
break
return computed_blocks

def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# Remove the blocks that are no longer be in the chunked attention
# window and skipped during the attention computation.

# [chunk 0][chunk 1]local_attention_start_idx ... current
# we computed previous number of chunks to get the idx of
# current chunk window starting offset,
# e.g. for computed 1024 tokens, the 1024th token (0 indexed)
# is in the second chunk, there are 1 prev chunk, the start idx
# is 1024. for 1023, it will be 0.
num_cached_block = self.num_cached_block.get(request_id, 0)
local_attention_start_idx = (
(num_computed_tokens)
// self.attention_chunk_size
* self.attention_chunk_size
)
first_useful_block_idx = local_attention_start_idx // self.block_size
if num_cached_block > 0:
# Make sure we don't delete the last cached block
first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1)
# if block size = 128, 0 -> block 0, 1024 (= 128 * 8) ->
# block 8, 372 (= 128 * 2 + 116) -> block 2
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
# we need to keep the last block to get the previous hash key
for i in range(first_useful_block_idx - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
Get the number of tokens that will be skipped for attention computation.

For chunked local attention, these are the tokens in the chunks that
precede the chunk containing the next token.

Example 1:
chunk size = 8, num_computed_tokens = 13

Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
          |--------- computed ----------|
                                         ^ next token to be computed
                           |---------------| attention window for next token
          |--- skipped --|

Output: get_num_skipped_tokens(13) == 8

Example 2:
chunk size = 8, num_computed_tokens = 8

Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
          |-- computed --|
                            ^ next token to be computed
                           |-| attention window for next token (token 8 only)
          |--- skipped --|

Output: get_num_skipped_tokens(8) == 8

Example 3:
chunk size = 8, num_computed_tokens = 7

Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
          |- computed -|
                        ^ next token to be computed
          |--------------| attention window for next token

No tokens are skipped.
Output: get_num_skipped_tokens(7) == 0

Args:
num_computed_tokens: The number of tokens that have been computed.

Returns:
The number of tokens that will be skipped for attention computation.
"""
num_skipped_tokens = (
num_computed_tokens // self.attention_chunk_size
) * self.attention_chunk_size
return num_skipped_tokens
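
The three docstring examples can be checked with a standalone helper (an
illustrative sketch mirroring the method above, not vLLM's API):

def chunked_skipped(num_computed_tokens: int, attention_chunk_size: int) -> int:
    # Round num_computed_tokens down to the nearest chunk boundary.
    return (num_computed_tokens // attention_chunk_size) * attention_chunk_size

assert chunked_skipped(13, 8) == 8  # Example 1: all of chunk 0 is skipped
assert chunked_skipped(8, 8) == 8   # Example 2: next token starts chunk 1
assert chunked_skipped(7, 8) == 0   # Example 3: still inside chunk 0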

def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
@@ -590,12 +637,6 @@ def find_longest_cache_hit(

return computed_blocks

def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# Here unused blocks may be freed up for running requests.
# TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
# (for which find_longest_cache_hit returns block_pool.null_block)
pass

def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by mamba
@@ -676,11 +717,6 @@ def find_longest_cache_hit(
# Return empty blocks to indicate no cache hits
raise NotImplementedError("CrossAttentionManager does not support caching")

def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# Cross-attention blocks represent encoder states which are needed
# for the entire decoding process, so no blocks should be skipped
pass


spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,