diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 8f14fb189470..14ac83028ee4 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -243,18 +243,53 @@ def find_longest_cache_hit(
         raise NotImplementedError
 
-    @abstractmethod
     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         """
-        Remove the blocks that are no longer needed from `blocks` and free the
-        blocks. The removed blocks should be replaced by null_block.
-        Need to be customized for each attention type.
+        Remove and free the blocks that are no longer needed for attention computation.
+        The removed blocks should be replaced by null_block.
+
+        This function depends on `get_num_skipped_tokens`, which needs to be implemented
+        differently for each attention type.
 
         Args:
             request_id: The request ID.
             num_computed_tokens: The number of tokens that have been computed.
         """
-        raise NotImplementedError
+        # Remove the blocks that will be skipped during attention computation.
+        num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
+        if num_skipped_tokens <= 0:
+            # This indicates that ALL tokens are inside the attention window,
+            # so we do not need to free any blocks outside the window.
+            # A typical case is full attention, where we never free any
+            # tokens before the request finishes.
+            return
+        num_skipped_blocks = num_skipped_tokens // self.block_size
+        blocks = self.req_to_blocks[request_id]
+        removed_blocks: list[KVCacheBlock] = []
+        # Because block indices start from 0, the num_skipped_blocks-th block
+        # corresponds to index num_skipped_blocks - 1.
+        for i in range(num_skipped_blocks - 1, -1, -1):
+            if blocks[i] == self._null_block:
+                # If the block is already a null block, the blocks before it
+                # should also have been set to null blocks by the previous calls
+                # to this function.
+                break
+            removed_blocks.append(blocks[i])
+            blocks[i] = self._null_block
+        self.block_pool.free_blocks(removed_blocks)
+
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention computation.
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        # The default behavior is to not skip any tokens.
+        return 0
 
 
 class FullAttentionManager(SingleTypeKVCacheManager):
@@ -298,10 +333,6 @@ def find_longest_cache_hit(
                 computed.pop()
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # No need to remove blocks for full attention.
-        pass
-
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         blocks = self.req_to_blocks[running_request_id]
         num_common_blocks = 0
@@ -389,28 +420,33 @@ def find_longest_cache_hit(
                 computed.pop()
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Remove the blocks that are no longer be in the sliding window and
-        # skipped during the attention computation.
-        last_useful_token = num_computed_tokens - self.sliding_window + 1
-        last_useful_block = last_useful_token // self.block_size
-        if last_useful_block <= 0:
-            # Early return if tokens are not enough to fill the sliding window
-            return
-        blocks = self.req_to_blocks[request_id]
-        if blocks[last_useful_block - 1] == self._null_block:
-            # Early return if there are no blocks to remove
-            return
-        removed_blocks: list[KVCacheBlock] = []
-        for i in range(last_useful_block - 1, -1, -1):
-            if blocks[i] == self._null_block:
-                # If the block is already a null block, the blocks before it
-                # should also have been set to null blocks by the previous calls
-                # to this function.
-                break
-            removed_blocks.append(blocks[i])
-            blocks[i] = self._null_block
-        self.block_pool.free_blocks(removed_blocks)
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention computation.
+
+        For sliding window attention, these are the tokens that lie before
+        the current sliding window.
+
+        Example:
+            sliding_window = 4, num_computed_tokens = 7
+
+            Tokens: [ 0 1 2 3 4 5 6 7 ]
+                      |-computed--|
+                                    ^ next token to be computed
+                              |-----| sliding window for next token
+                      |-----| skipped
+
+        The current window contains tokens 4~7. Tokens 0~3 will be skipped
+        for attention computation since they are outside the sliding window.
+        Thus, get_num_skipped_tokens(7) == 4.
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        return num_computed_tokens - self.sliding_window + 1
 
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
@@ -511,40 +547,51 @@ def find_longest_cache_hit(
                     break
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Remove the blocks that are no longer be in the chunked attention
-        # window and skipped during the attention computation.
-
-        # [chunk 0][chunk 1]local_attention_start_idx ... current
-        # we computed previous number of chunks to get the idx of
-        # current chunk window starting offset,
-        # e.g. for computed 1024 tokens, the 1024th token (0 indexed)
-        # is in the second chunk, there are 1 prev chunk, the start idx
-        # is 1024. for 1023, it will be 0.
-        num_cached_block = self.num_cached_block.get(request_id, 0)
-        local_attention_start_idx = (
-            (num_computed_tokens)
-            // self.attention_chunk_size
-            * self.attention_chunk_size
-        )
-        first_useful_block_idx = local_attention_start_idx // self.block_size
-        if num_cached_block > 0:
-            # Make sure we don't delete the last cached block
-            first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1)
-        # if block size = 128, 0 -> block 0, 1024 (= 128 * 8) ->
-        # block 8, 372 (= 128 * 2 + 116) -> block 2
-        blocks = self.req_to_blocks[request_id]
-        removed_blocks: list[KVCacheBlock] = []
-        # we need to keep the last block to get the previous hash key
-        for i in range(first_useful_block_idx - 1, -1, -1):
-            if blocks[i] == self._null_block:
-                # If the block is already a null block, the blocks before it
-                # should also have been set to null blocks by the previous calls
-                # to this function.
-                break
-            removed_blocks.append(blocks[i])
-            blocks[i] = self._null_block
-        self.block_pool.free_blocks(removed_blocks)
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention computation.
+
+        For chunked local attention, these are the tokens before the current
+        attention chunk.
+
+        Example 1:
+            chunk size = 8, num_computed_tokens = 13
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                      |--------- computed ---------|
+                                                     ^ next token to be computed
+                                      |-------------| <-- attention window for
+                                                          next token
+                      |--- skipped -|
+            Output: get_num_skipped_tokens(13) == 8
+
+        Example 2:
+            chunk size = 8, num_computed_tokens = 8
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                      |-- computed -|
+                                          ^ next token to be computed
+                                          | <-- attention window for next token
+                      |--- skipped -|
+            Output: get_num_skipped_tokens(8) == 8
+
+        Example 3:
+            chunk size = 8, num_computed_tokens = 7
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                      |-computed--|
+                                    ^ next token to be computed
+                      |-------------| <-- attention window for next token
+            No tokens are skipped.
+            Output: get_num_skipped_tokens(7) == 0
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        num_skipped_tokens = (
+            num_computed_tokens // self.attention_chunk_size
+        ) * self.attention_chunk_size
+        return num_skipped_tokens
 
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
@@ -590,12 +637,6 @@ def find_longest_cache_hit(
 
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Here unused blocks may be freed up for running requests.
-        # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
-        # (for which find_longest_cache_hit returns block_pool.null_block)
-        pass
-
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         cascade attention is not supported by mamba
@@ -676,11 +717,6 @@ def find_longest_cache_hit(
         # Return empty blocks to indicate no cache hits
         raise NotImplementedError("CrossAttentionManager does not support caching")
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Cross-attention blocks represent encoder states which are needed
-        # for the entire decoding process, so no blocks should be skipped
-        pass
-
 
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
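The net effect of the diff is a template-method refactor: `remove_skipped_blocks` now lives once on the base class and defers to a per-attention-type `get_num_skipped_tokens` hook, so subclasses that never skip anything (full attention, Mamba, cross-attention) simply inherit the default `return 0`. The sketch below illustrates that control flow in isolation; `ToyManager`, `NULL_BLOCK`, and the string block IDs are hypothetical stand-ins, not vLLM's actual `SingleTypeKVCacheManager`, `KVCacheBlock`, or `BlockPool` APIs.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for vLLM's shared null-block placeholder.
NULL_BLOCK = object()


@dataclass
class ToyManager:
    """Minimal mirror of the refactored base-class flow (illustration only)."""

    block_size: int
    req_to_blocks: dict[str, list] = field(default_factory=dict)
    freed: list = field(default_factory=list)

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        # Default behavior (e.g. full attention): never skip any tokens.
        return 0

    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
        num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
        if num_skipped_tokens <= 0:
            return  # all tokens are still inside the attention window
        blocks = self.req_to_blocks[request_id]
        # Walk backwards from the last fully skipped block and stop at the
        # first block that was already nulled by a previous call.
        for i in range(num_skipped_tokens // self.block_size - 1, -1, -1):
            if blocks[i] is NULL_BLOCK:
                break
            self.freed.append(blocks[i])
            blocks[i] = NULL_BLOCK


@dataclass
class ToySlidingWindowManager(ToyManager):
    sliding_window: int = 4

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        return num_computed_tokens - self.sliding_window + 1


m = ToySlidingWindowManager(block_size=2)
m.req_to_blocks["r0"] = ["b0", "b1", "b2", "b3"]
m.remove_skipped_blocks("r0", num_computed_tokens=7)
# get_num_skipped_tokens(7) == 4 tokens == 2 blocks, freed back-to-front.
assert m.freed == ["b1", "b0"]
assert m.req_to_blocks["r0"][:2] == [NULL_BLOCK, NULL_BLOCK]
```

Freeing back-to-front and stopping at the first already-null block is what keeps repeated calls proportional to the newly skipped blocks rather than rescanning the whole prefix each time.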
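The sliding-window hook is a one-liner, so it can be sanity-checked directly against the docstring example. `num_skipped_sliding` below is a hypothetical free-standing copy of the return expression, not the manager method itself:

```python
def num_skipped_sliding(num_computed_tokens: int, sliding_window: int) -> int:
    # The next token attends to itself plus the previous sliding_window - 1
    # tokens, so everything before num_computed_tokens - sliding_window + 1
    # falls outside the window.
    return num_computed_tokens - sliding_window + 1


assert num_skipped_sliding(7, 4) == 4   # docstring example: tokens 0~3 skipped
assert num_skipped_sliding(3, 4) == 0   # window exactly full: nothing skipped yet
assert num_skipped_sliding(2, 4) == -1  # negative; caller treats <= 0 as "skip none"
```

Note that the result can go negative early in a request, which is why the base-class `remove_skipped_blocks` guards on `num_skipped_tokens <= 0` rather than `== 0`.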
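The chunked-local-attention hook is likewise pure arithmetic: round `num_computed_tokens` down to the nearest chunk boundary, which is the start of the chunk containing the next token. A hypothetical standalone helper reproducing the three docstring examples:

```python
def num_skipped_chunked(num_computed_tokens: int, attention_chunk_size: int) -> int:
    # Round down to the start of the chunk that contains the next token.
    return (num_computed_tokens // attention_chunk_size) * attention_chunk_size


assert num_skipped_chunked(13, 8) == 8  # Example 1: next token 13 is in chunk [8, 16)
assert num_skipped_chunked(8, 8) == 8   # Example 2: token 8 starts a fresh chunk
assert num_skipped_chunked(7, 8) == 0   # Example 3: next token 7 is still in chunk [0, 8)
```

Unlike the sliding-window variant, this can never go negative, but it jumps by a whole `attention_chunk_size` at each chunk boundary (see Example 2), so a batch of blocks becomes skippable all at once.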