From 9ca160ece6b0aa4c83ba5c6c755f5865e9f5abde Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 22 Sep 2025 16:15:07 -0700 Subject: [PATCH 1/7] Introduce get_last_useful_token to get the last useful token inside the attention window Signed-off-by: KuntaiDu --- vllm/v1/core/single_type_kv_cache_manager.py | 119 +++++++++---------- 1 file changed, 57 insertions(+), 62 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d27239164b0d..b6476d25bb3f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -231,19 +231,47 @@ def find_longest_cache_hit( raise NotImplementedError - @abstractmethod def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ Remove the blocks that are no longer needed from `blocks` and free the blocks. The removed blocks should be replaced by null_block. Need to be customized for each attention type. - Args: request_id: The request ID. num_computed_tokens: The number of tokens that have been computed. """ - raise NotImplementedError + # Remove the blocks that are no longer be in the sliding window and + # skipped during the attention computation. + last_useful_token = self.get_last_useful_token(num_computed_tokens) + if last_useful_token <= 0: + # This indicates that ALL tokens are inside attention window. + # Thus we do not need to free any blocks outside attention window. + return + last_useful_block = last_useful_token // self.block_size + blocks = self.req_to_blocks[request_id] + removed_blocks: list[KVCacheBlock] = [] + for i in range(last_useful_block - 1, -1, -1): + if blocks[i] == self._null_block: + # If the block is already a null block, the blocks before it + # should also have been set to null blocks by the previous calls + # to this function. 
+ break + removed_blocks.append(blocks[i]) + blocks[i] = self._null_block + self.block_pool.free_blocks(removed_blocks) + + def get_last_useful_token(self, num_computed_tokens: int) -> int: + """ + Get the last token (leftmost token) index that is inside attn window. + + Args: + num_computed_tokens: The number of tokens that have been computed. + + Returns: + The last token (leftmost token) index that is inside attn window. + """ + return 0 class FullAttentionManager(SingleTypeKVCacheManager): @@ -284,11 +312,6 @@ def find_longest_cache_hit( computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - # No need to remove blocks for full attention. - pass - def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: blocks = self.req_to_blocks[request_id] @@ -372,23 +395,19 @@ def find_longest_cache_hit( computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - # Remove the blocks that are no longer be in the sliding window and - # skipped during the attention computation. - last_useful_token = num_computed_tokens - self.sliding_window + 1 - last_useful_block = last_useful_token // self.block_size - blocks = self.req_to_blocks[request_id] - removed_blocks: list[KVCacheBlock] = [] - for i in range(last_useful_block - 1, -1, -1): - if blocks[i] == self._null_block: - # If the block is already a null block, the blocks before it - # should also have been set to null blocks by the previous calls - # to this function. - break - removed_blocks.append(blocks[i]) - blocks[i] = self._null_block - self.block_pool.free_blocks(removed_blocks) + def get_last_useful_token(self, num_computed_tokens: int) -> int: + """ + Get the last token (leftmost token) index that is inside + sliding window. + + Args: + num_computed_tokens: The number of tokens that have been computed. 
+ + Returns: + The last token (leftmost token) index that is inside + sliding window. + """ + return num_computed_tokens - self.sliding_window + 1 def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: @@ -483,40 +502,22 @@ def find_longest_cache_hit( break return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - # Remove the blocks that are no longer be in the chunked attention - # window and skipped during the attention computation. - - # [chunk 0][chunk 1]local_attention_start_idx ... current - # we computed previous number of chunks to get the idx of - # current chunk window starting offset, - # e.g. for computed 1024 tokens, the 1024th token (0 indexed) - # is in the second chunk, there are 1 prev chunk, the start idx - # is 1024. for 1023, it will be 0. - num_cached_block = self.num_cached_block.get(request_id, 0) + def get_last_useful_token(self, num_computed_tokens: int) -> int: + """ + Get the last token (leftmost token) index that is inside chunked local + attention window. + + Args: + num_computed_tokens: The number of tokens that have been computed. + + Returns: + The last token (leftmost token) index that is inside chunked local + attention window. 
+ """ local_attention_start_idx = ( num_computed_tokens ) // self.attention_chunk_size * self.attention_chunk_size - first_useful_block_idx = local_attention_start_idx // self.block_size - if num_cached_block > 0: - # Make sure we don't delete the last cached block - first_useful_block_idx = min(first_useful_block_idx, - num_cached_block - 1) - # if block size = 128, 0 -> block 0, 1024 (= 128 * 8) -> - # block 8, 372 (= 128 * 2 + 116) -> block 2 - blocks = self.req_to_blocks[request_id] - removed_blocks: list[KVCacheBlock] = [] - # we need to keep the last block to get the previous hash key - for i in range(first_useful_block_idx - 1, -1, -1): - if blocks[i] == self._null_block: - # If the block is already a null block, the blocks before it - # should also have been set to null blocks by the previous calls - # to this function. - break - removed_blocks.append(blocks[i]) - blocks[i] = self._null_block - self.block_pool.free_blocks(removed_blocks) + return local_attention_start_idx def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: @@ -549,12 +550,6 @@ def find_longest_cache_hit( [] for _ in range(len(kv_cache_group_ids))) return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - # Each request will always have 1 block at this moment, so no need to - # remove blocks. 
- pass - def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: return 0 From c33ecd0298b27a98c85bae7aa896483d77805416 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 22 Sep 2025 16:18:41 -0700 Subject: [PATCH 2/7] add comments Signed-off-by: KuntaiDu --- vllm/v1/core/single_type_kv_cache_manager.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index b6476d25bb3f..9d477611d3f1 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -271,6 +271,7 @@ def get_last_useful_token(self, num_computed_tokens: int) -> int: Returns: The last token (leftmost token) index that is inside attn window. """ + # The default behavior is to not remove any blocks. return 0 @@ -642,12 +643,6 @@ def find_longest_cache_hit( raise NotImplementedError( "CrossAttentionManager does not support caching") - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - # Cross-attention blocks represent encoder states which are needed - # for the entire decoding process, so no blocks should be skipped - pass - spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, From ac0a89598bf8c5126d2f0c35e16081f7648dd14d Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 3 Nov 2025 16:32:41 -0800 Subject: [PATCH 3/7] handle Chen's suggestion Signed-off-by: KuntaiDu --- vllm/v1/core/single_type_kv_cache_manager.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7e32a044f4c5..5cf4b7641df5 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -248,18 +248,21 @@ def remove_skipped_blocks(self, request_id: str, 
num_computed_tokens: int) -> No Remove the blocks that are no longer needed from `blocks` and free the blocks. The removed blocks should be replaced by null_block. Need to be customized for each attention type. + Args: request_id: The request ID. num_computed_tokens: The number of tokens that have been computed. """ # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. - last_useful_token = self.get_last_useful_token(num_computed_tokens) - if last_useful_token <= 0: + num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens) + if num_skipped_tokens <= 0: # This indicates that ALL tokens are inside attention window. # Thus we do not need to free any blocks outside attention window. + # A typical case is full attention that we never free any token + # before the request is finished. return - last_useful_block = last_useful_token // self.block_size + last_useful_block = num_skipped_tokens // self.block_size blocks = self.req_to_blocks[request_id] removed_blocks: list[KVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): @@ -272,15 +275,15 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_last_useful_token(self, num_computed_tokens: int) -> int: + def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ - Get the last token (leftmost token) index that is inside attn window. + Get the last token (leftmost token) index that is inside attention window. Args: num_computed_tokens: The number of tokens that have been computed. Returns: - The last token (leftmost token) index that is inside attn window. + The last token (leftmost token) index that is inside attention window. """ # The default behavior is to not remove any blocks. 
return 0 @@ -414,7 +417,7 @@ def find_longest_cache_hit( computed.pop() return computed_blocks - def get_last_useful_token(self, num_computed_tokens: int) -> int: + def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ Get the last token (leftmost token) index that is inside sliding window. @@ -527,7 +530,7 @@ def find_longest_cache_hit( break return computed_blocks - def get_last_useful_token(self, num_computed_tokens: int) -> int: + def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ Get the last token (leftmost token) index that is inside chunked local attention window. From 8cd2b21bcefabf1f46d409764ec24f5e56f7f2bc Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 4 Nov 2025 15:09:06 -0800 Subject: [PATCH 4/7] adjust the docstring Signed-off-by: KuntaiDu --- vllm/v1/core/single_type_kv_cache_manager.py | 83 ++++++++++++++------ 1 file changed, 59 insertions(+), 24 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 5cf4b7641df5..0c2f05cf2278 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -245,16 +245,14 @@ def find_longest_cache_hit( def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and free the - blocks. The removed blocks should be replaced by null_block. - Need to be customized for each attention type. + Remove and free the blocks that are no longer needed for attention computation. + The removed blocks should be replaced by null_block. Args: request_id: The request ID. num_computed_tokens: The number of tokens that have been computed. """ - # Remove the blocks that are no longer be in the sliding window and - # skipped during the attention computation. + # Remove the blocks that will be skipped during attention computation. 
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens) if num_skipped_tokens <= 0: # This indicates that ALL tokens are inside attention window. @@ -262,10 +260,12 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No # A typical case is full attention that we never free any token # before the request is finished. return - last_useful_block = num_skipped_tokens // self.block_size + num_skipped_blocks = num_skipped_tokens // self.block_size blocks = self.req_to_blocks[request_id] removed_blocks: list[KVCacheBlock] = [] - for i in range(last_useful_block - 1, -1, -1): + # Because the block starts from index 0, the num_skipped_block-th block + # corresponds to index num_skipped_blocks - 1. + for i in range(num_skipped_blocks - 1, -1, -1): if blocks[i] == self._null_block: # If the block is already a null block, the blocks before it # should also have been set to null blocks by the previous calls @@ -277,15 +277,15 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ - Get the last token (leftmost token) index that is inside attention window. + Get the number of tokens that will be skipped for attention computation. Args: num_computed_tokens: The number of tokens that have been computed. Returns: - The last token (leftmost token) index that is inside attention window. + The number of tokens that will be skipped for attention computation. """ - # The default behavior is to not remove any blocks. + # The default behavior is to not skip any tokens. return 0 @@ -419,15 +419,27 @@ def find_longest_cache_hit( def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ - Get the last token (leftmost token) index that is inside - sliding window. + Get the number of tokens that will be skipped for attention computation. + + For sliding window, this corresponds to the tokens that are prior to + the current sliding window. 
+ + For example, if sliding_window is 4 and num_computed_tokens is 7 + (i.e., tokens 0~6 have been computed): + + Tokens: [ 0 1 2 | 3 4 5 6 ] + ^^^^^ ^^^^^^^^^^^^ + skipped sliding window + + The current window contains tokens 3~6. Tokens 0~2 will be skipped for + attention computation since they are outside the sliding window. + Thus, get_num_skipped_tokens(7) == 3. Args: num_computed_tokens: The number of tokens that have been computed. Returns: - The last token (leftmost token) index that is inside - sliding window. + The number of tokens that will be skipped for attention computation. """ return num_computed_tokens - self.sliding_window + 1 @@ -532,22 +544,45 @@ def find_longest_cache_hit( def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ - Get the last token (leftmost token) index that is inside chunked local - attention window. + Get the number of tokens that will be skipped for attention computation. + + For chunked local attention, this corresponds to the tokens that are on + the left side of the current chunk. + + Example 1: + # chunk size = 8, num_computed_tokens = 13 + # Tokens: 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 | ... + # | ----- computed -------------| + # ^^ next token to be computed + # [skipped] [current chunk/attended] + # Output: get_num_skipped_tokens(13) == 8 + + Example 2: + # chunk size = 8, num_computed_tokens = 8 + # Tokens: 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 | ... + # | --- computed -| + # ^ next token to be computed + # [skipped] |[current chunk/attended] + # Output: get_num_skipped_tokens(8) == 8 + + Example 3: + # chunk size = 8, num_computed_tokens = 7 + # Tokens: 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 | ... + # |--computed-| + # ^ next token to be computed + # |[current chunk]| + # Output: get_num_skipped_tokens(7) == 0 Args: num_computed_tokens: The number of tokens that have been computed. Returns: - The last token (leftmost token) index that is inside chunked local - attention window. 
+ The number of tokens that will be skipped for attention computation. """ - local_attention_start_idx = ( - (num_computed_tokens) - // self.attention_chunk_size - * self.attention_chunk_size - ) - return local_attention_start_idx + num_skipped_tokens = ( + num_computed_tokens // self.attention_chunk_size + ) * self.attention_chunk_size + return num_skipped_tokens def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ From f91f8282a6209b82abc90641e6ccb2936a170b55 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 4 Nov 2025 15:19:32 -0800 Subject: [PATCH 5/7] adjust the docstring Signed-off-by: KuntaiDu --- vllm/v1/core/single_type_kv_cache_manager.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 0c2f05cf2278..c08326558cfe 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -424,16 +424,18 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: For sliding window, this corresponds to the tokens that are prior to the current sliding window. - For example, if sliding_window is 4 and num_computed_tokens is 7 - (i.e., tokens 0~6 have been computed): + Example: + sliding_window=4, num_computed_tokens=7 - Tokens: [ 0 1 2 | 3 4 5 6 ] - ^^^^^ ^^^^^^^^^^^^ - skipped sliding window + Tokens: [ 0 1 2 3 4 5 6 7 ] + | ---- computed -----| + ^ next token to be computed + |-----------| sliding window for next token + |--skipped--| - The current window contains tokens 3~6. Tokens 0~2 will be skipped for + The current window contains tokens 4~7. Tokens 0~3 will be skipped for attention computation since they are outside the sliding window. - Thus, get_num_skipped_tokens(7) == 3. + Thus, get_num_skipped_tokens(7) == 4. Args: num_computed_tokens: The number of tokens that have been computed. 
From 8b0592ed66794263c974d47326e16c6a35aeb232 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 4 Nov 2025 15:20:53 -0800 Subject: [PATCH 6/7] mention num_skipped_tokens in docstring Signed-off-by: KuntaiDu --- vllm/v1/core/single_type_kv_cache_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index c08326558cfe..0490aa8ba3e6 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -248,6 +248,9 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No Remove and free the blocks that are no longer needed for attention computation. The removed blocks should be replaced by null_block. + This function depends on `get_num_skipped_tokens`, which needs to be implemented + differently for each attention type. + Args: request_id: The request ID. num_computed_tokens: The number of tokens that have been computed. 
From 1207f5e60197d4f79e52fd5e9567da2cd05ffeb5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 5 Nov 2025 00:10:09 -0800 Subject: [PATCH 7/7] adjust figure Signed-off-by: KuntaiDu --- vllm/v1/core/single_type_kv_cache_manager.py | 50 +++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 0490aa8ba3e6..14ac83028ee4 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -430,11 +430,11 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: Example: sliding_window=4, num_computed_tokens=7 - Tokens: [ 0 1 2 3 4 5 6 7 ] - | ---- computed -----| - ^ next token to be computed - |-----------| sliding window for next token - |--skipped--| + Tokens: [ 0 1 2 3 4 5 6 7 ] + | ---- computed -----| + ^ next token to be computed + |-----------| sliding window for next token + |--skipped---| The current window contains tokens 4~7. Tokens 0~3 will be skipped for attention computation since they are outside the sliding window. @@ -555,28 +555,32 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: the left side of the current chunk. Example 1: - # chunk size = 8, num_computed_tokens = 13 - # Tokens: 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 | ... - # | ----- computed -------------| - # ^^ next token to be computed - # [skipped] [current chunk/attended] - # Output: get_num_skipped_tokens(13) == 8 + chunk size = 8, num_computed_tokens = 13 + Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ... + | ----- computed ---------------| + ^^ next token to be computed + |----------------| <-- attention window for + next token + |--- skipped -----| + Output: get_num_skipped_tokens(13) == 8 Example 2: - # chunk size = 8, num_computed_tokens = 8 - # Tokens: 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 | ... 
- # | --- computed -| - # ^ next token to be computed - # [skipped] |[current chunk/attended] - # Output: get_num_skipped_tokens(8) == 8 + chunk size = 8, num_computed_tokens = 8 + Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ... + | --- computed ---| + ^ next token to be computed + |--| <-- attention window for next token + | --- skipped ----| + Output: get_num_skipped_tokens(8) == 8 Example 3: - # chunk size = 8, num_computed_tokens = 7 - # Tokens: 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 | ... - # |--computed-| - # ^ next token to be computed - # |[current chunk]| - # Output: get_num_skipped_tokens(7) == 0 + chunk size = 8, num_computed_tokens = 7 + Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ... + |---computed---| + ^ next token to be computed + |-----------------| <-- attention window for next token + no token should be skipped. + Output: get_num_skipped_tokens(7) == 0 Args: num_computed_tokens: The number of tokens that have been computed.