diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d598d12571f1..8956393c0bfb 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,12 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 """Compare the with and without prefix caching.""" +from typing import List + import pytest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + hash_block_tokens) def make_request(request_id, @@ -62,14 +66,14 @@ def test_prefill(): for block_id in (0, 1, 2): block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) block_hash = hash_block_tokens(parent_block_hash, block_tokens) - assert manager.block_pool[block_id].block_hash == block_hash - assert manager.block_pool[block_id].ref_cnt == 1 + assert manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial/preallocated block metadata for block_id in (3, 4): - assert manager.block_pool[block_id].block_hash is None - assert manager.block_pool[block_id].ref_cnt == 1 + assert manager.block_pool.blocks[block_id].block_hash is None + assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) @@ -86,20 +90,21 @@ def test_prefill(): assert block.ref_cnt == 2 # At this point, we should have 3 free blocks left. - assert manager.free_block_queue.num_free_blocks == 3 + assert manager.block_pool.free_block_queue.num_free_blocks == 3 manager.free(req0) manager.free(req1) # All blocks should be available. - assert manager.free_block_queue.num_free_blocks == 10 + assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be # [unallocated (7, 8, 9)] # [unique_req0 (4, 3)] # [unique_req1 (6, 5)] # [common (2, 1, 0)] assert [ - b.block_id for b in manager.free_block_queue.get_all_free_blocks() + b.block_id + for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0] # Cache hit in the common prefix when the original block is already free. @@ -116,12 +121,14 @@ def test_prefill(): # Although we only have 5 free blocks, we have 8 blocks in # the free block queue due to lazy removal. - assert manager.free_block_queue.num_free_blocks == 5 + assert manager.block_pool.free_block_queue.num_free_blocks == 5 assert all([ - b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks() + b.ref_cnt == 0 + for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) - assert len([b - for b in manager.free_block_queue.get_all_free_blocks()]) == 5 + assert len([ + b for b in manager.block_pool.free_block_queue.get_all_free_blocks() + ]) == 5 manager.free(req2) @@ -133,9 +140,9 @@ def test_prefill(): blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) # This block ID order also checks the eviction order. 
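Note: the free-list ordering asserted above falls out of three rules: never-allocated blocks stay at the head of the queue, each request's blocks are freed in reverse order, and blocks shared by several requests only rejoin the queue once their last user frees them. Below is a simplified, self-contained model of that bookkeeping (the ToyPool class is illustrative only, not vLLM's BlockPool or FreeKVCacheBlockQueue); it reproduces the [7, 8, 9, 4, 3, 6, 5, 2, 1, 0] order checked in test_prefill above.

from collections import deque

class ToyPool:
    """Simplified model of the free-queue ordering; not the real implementation."""

    def __init__(self, num_blocks):
        self.ref = [0] * num_blocks
        # Head of the queue is the next block to be reused/evicted.
        self.free_queue = deque(range(num_blocks))

    def allocate(self, n):
        ids = [self.free_queue.popleft() for _ in range(n)]
        for i in ids:
            self.ref[i] += 1
        return ids

    def touch(self, ids):
        # A prefix-cache hit on a block that is still free removes it from
        # the queue so it cannot be evicted while the new request uses it.
        for i in ids:
            if self.ref[i] == 0:
                self.free_queue.remove(i)
            self.ref[i] += 1

    def free(self, ids):
        # Free in reverse order so the deepest blocks become eviction
        # candidates first.
        for i in reversed(ids):
            self.ref[i] -= 1
            if self.ref[i] == 0:
                self.free_queue.append(i)

pool = ToyPool(10)
req0 = pool.allocate(5)                 # blocks 0-4
pool.touch(req0[:3])                    # req1 hits the 3-block common prefix
req1 = req0[:3] + pool.allocate(2)      # plus unique blocks 5-6
pool.free(req0)
pool.free(req1)
assert list(pool.free_queue) == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]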
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] - assert manager.free_block_queue.num_free_blocks == 0 - assert manager.free_block_queue.free_list_head is None - assert manager.free_block_queue.free_list_tail is None + assert manager.block_pool.free_block_queue.num_free_blocks == 0 + assert manager.block_pool.free_block_queue.free_list_head is None + assert manager.block_pool.free_block_queue.free_list_tail is None def test_decode(): @@ -219,13 +226,14 @@ def test_evict(): assert len(blocks) == 3 # 3 full blocks last_token_id += 3 * 16 - assert manager.free_block_queue.num_free_blocks == 0 + assert manager.block_pool.free_block_queue.num_free_blocks == 0 manager.free(req0) manager.free(req1) - assert manager.free_block_queue.num_free_blocks == 10 + assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert [ - b.block_id for b in manager.free_block_queue.get_all_free_blocks() + b.block_id + for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7] # Touch the first 2 blocks. @@ -235,7 +243,7 @@ def test_evict(): assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) assert [b.block_id for b in blocks] == [6, 5] - assert manager.free_block_queue.num_free_blocks == 6 + assert manager.block_pool.free_block_queue.num_free_blocks == 6 def test_hash_block_correct_reuse(): @@ -274,7 +282,7 @@ def test_hash_block_correct_reuse(): blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) assert len(blocks) == 1 - assert manager.block_pool[blocks[0].block_id].block_hash is None + assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -413,13 +421,9 @@ def test_cache_blocks(): function of KVCacheManager. """ block_size = 4 - manager = KVCacheManager( - block_size=block_size, + block_pool = BlockPool( num_gpu_blocks=5, - max_model_len=8192, - sliding_window=None, enable_caching=True, - num_preallocate_tokens=0, ) # Req: # Block 0: [0, 1, 2, 3] @@ -430,26 +434,31 @@ def test_cache_blocks(): # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] + block_hashes: List[BlockHashType] = [] - manager._cache_full_blocks( + block_pool.cache_full_blocks( request=req, - blk_start_idx=0, - full_blocks=blocks, - prev_block=None, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=0, + num_full_blocks=2, + block_size=block_size, ) - assert len(manager.cached_block_hash_to_block) == 2 + assert len(block_pool.cached_block_hash_to_block) == 2 assert all([block.block_hash is not None for block in blocks]) # Test that blocks that don't start from the beginning are cached correctly. - blocks = [KVCacheBlock(block_id=2)] - manager._cache_full_blocks( + blocks += [KVCacheBlock(block_id=2)] + block_pool.cache_full_blocks( request=req, - blk_start_idx=2, - full_blocks=blocks, - prev_block=None, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=2, + num_full_blocks=3, + block_size=block_size, ) - assert len(manager.cached_block_hash_to_block) == 3 + assert len(block_pool.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None @@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. 
- assert manager.free_block_queue.num_free_blocks == 5 + assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks == block_part1 @@ -621,12 +630,12 @@ def test_reset_prefix_cache(): # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() - assert manager.cached_block_hash_to_block + assert manager.block_pool.cached_block_hash_to_block # Free the blocks. manager.free(req0) manager.free(req1) assert manager.reset_prefix_cache() - assert not manager.cached_block_hash_to_block - assert all([blk.block_hash is None for blk in manager.block_pool]) + assert not manager.block_pool.cached_block_hash_to_block + assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py new file mode 100644 index 000000000000..5ef495c7eed8 --- /dev/null +++ b/vllm/v1/core/block_pool.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections import defaultdict +from typing import Dict, Iterable, List, Optional + +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, + KVCacheBlock, + generate_block_hash_extra_keys, + hash_block_tokens) +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class BlockPool: + """BlockPool that manages KVCacheBlocks. + It provides methods to allocate, free and cache the kv cache blocks. The + free_block_queue stores the free blocks in eviction order to enable + allocation, free, and cache eviction. The cached_block_hash_to_block + maps between block hash and cached block to support finding cached blocks + by their block hash. + + Args: + num_gpu_blocks: The number of blocks in the pool. + enable_caching: Whether to enable prefix caching. + """ + + def __init__(self, num_gpu_blocks: int, enable_caching: bool): + self.num_gpu_blocks = num_gpu_blocks + self.enable_caching = enable_caching + # All kv-cache blocks. + self.blocks: List[KVCacheBlock] = [ + KVCacheBlock(idx) for idx in range(num_gpu_blocks) + ] + # Free block queue that constructs and manipulates a doubly linked + # list of free blocks (including eviction candidates when caching is + # enabled). + self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) + + # {block_hash: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. + # The cached block may be used by running requests or in the + # free_block_queue that could potentially be evicted. + # NOTE: We currently don't de-duplicate the blocks in the cache, + # meaning that if a block becomes full and is cached, we don't check + # if there is already an identical block in the cache. This is because + # we want to make sure the allocated block IDs won't change so that + # block tables are append-only. + self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ + int, KVCacheBlock]] = defaultdict(dict) + + def get_cached_block(self, + block_hash: BlockHashType) -> Optional[KVCacheBlock]: + """Get a cached block by the block hash, or None if cache miss. + If there are duplicated blocks, we return the first block in the cache. + + Args: + block_hash: The hash value of the block. + + Returns: + The cached block if it exists, or None. 
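Note: cached_block_hash_to_block is a two-level mapping, block hash -> {block ID: block}, so two physical blocks holding identical content can both be cached; get_cached_block just returns the first entry. A rough stdlib-only sketch of that lookup follows (plain dicts stand in for KVCacheBlock, and the hashes and block IDs are made up):

from collections import defaultdict

cached_block_hash_to_block = defaultdict(dict)
# Two blocks with identical content are both kept: block IDs must stay stable
# because block tables are append-only, so there is no de-duplication.
cached_block_hash_to_block["hash_a"][0] = {"block_id": 0}
cached_block_hash_to_block["hash_a"][7] = {"block_id": 7}

def get_cached_block(block_hash):
    entries = cached_block_hash_to_block.get(block_hash)
    if entries:
        first_block_id = next(iter(entries))
        return entries[first_block_id]
    return None

assert get_cached_block("hash_a")["block_id"] == 0   # first cached copy wins
assert get_cached_block("missing") is None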
+        """
+        if block_hash in self.cached_block_hash_to_block:
+            first_block_id = list(
+                self.cached_block_hash_to_block[block_hash].keys())[0]
+            return self.cached_block_hash_to_block[block_hash][first_block_id]
+        return None
+
+    def cache_full_blocks(
+        self,
+        request: Request,
+        blocks: List[KVCacheBlock],
+        block_hashes: List[BlockHashType],
+        num_cached_blocks: int,
+        num_full_blocks: int,
+        block_size: int,
+    ) -> None:
+        """Cache a list of full blocks for prefix caching.
+        This function takes a list of blocks whose block hash metadata will be
+        updated and cached. Given a request, it computes the block hashes for
+        the blocks starting from `num_cached_blocks` to `num_full_blocks`,
+        updating the metadata for each block and caching them in
+        `cached_block_hash_to_block`.
+
+        Args:
+            request: The request to cache the blocks.
+            blocks: All blocks in the request.
+            block_hashes: Block hashes of the blocks in the request. Note that
+                this list may be shorter than the blocks list. In this case the
+                missing block hashes will be computed in this function.
+            num_cached_blocks: The number of blocks that are already cached.
+            num_full_blocks: The number of blocks that are full and should
+                be cached after this function returns.
+            block_size: Number of tokens in each block.
+        """
+        if num_cached_blocks == num_full_blocks:
+            return
+        new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
+        assert len(block_hashes) >= num_cached_blocks
+        new_block_hashes = block_hashes[num_cached_blocks:]
+
+        # Update the new blocks with the block hashes through the chain.
+        if num_cached_blocks == 0:
+            prev_block_hash_value = None
+        else:
+            prev_block = blocks[num_cached_blocks - 1]
+            assert prev_block.block_hash is not None
+            prev_block_hash_value = prev_block.block_hash.hash_value
+
+        # Find the first uncached block.
+        # FIXME: num_cached_blocks should be corrected by the caller
+        # so this should never happen.
+        offset = 0
+        for blk in new_full_blocks:
+            if blk.block_hash is None:
+                break
+            else:
+                prev_block_hash_value = blk.block_hash.hash_value
+                offset += 1
+        else:
+            # All blocks are cached.
+            return
+
+        for i, blk in enumerate(new_full_blocks[offset:]):
+            blk_idx = num_cached_blocks + offset + i
+            assert blk.block_hash is None
+
+            if i + offset < len(new_block_hashes):
+                # The block hash may already be computed in
+                # "get_computed_blocks" if the tokens are not generated by
+                # this request (either the prompt tokens or the previously
+                # generated tokens with preemption). In this case we simply
+                # reuse the block hash.
+                block_hash = new_block_hashes[i + offset]
+            else:
+                # Otherwise compute the block hash and cache it in the request
+                # in case it will be preempted in the future.
+                start_token_idx = blk_idx * block_size
+                end_token_idx = (blk_idx + 1) * block_size
+                block_tokens = request.all_token_ids[
+                    start_token_idx:end_token_idx]
+                assert len(block_tokens) == block_size, (
+                    f"Expected {block_size} tokens, got "
+                    f"{len(block_tokens)} at {blk_idx}th block for request "
+                    f"{request.request_id}({request})")
+
+                # Generate extra keys for multi-modal inputs. Note that since
+                # we reach this branch only when the block is completed with
+                # generated tokens, we only need to consider the last mm input.
+                extra_keys, _ = generate_block_hash_extra_keys(
+                    request, start_token_idx, end_token_idx, -1)
+
+                # Compute the hash of the current block.
+                block_hash = hash_block_tokens(prev_block_hash_value,
+                                               block_tokens, extra_keys)
+                block_hashes.append(block_hash)
+
+            # Update and add the full block to the cache.
+            blk.block_hash = block_hash
+            self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
+            prev_block_hash_value = block_hash.hash_value
+
+    def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
+        """Get new blocks from the free block pool.
+
+        Note that we do not check the block cache in this function.
+
+        Args:
+            num_blocks: The number of blocks to allocate.
+
+        Returns:
+            A list of new blocks.
+        """
+        if num_blocks > self.get_num_free_blocks():
+            raise ValueError(
+                f"Cannot get {num_blocks} free blocks from the pool")
+
+        ret: List[KVCacheBlock] = []
+        idx = 0
+        while idx < num_blocks:
+            # First allocate blocks.
+            curr_block = self.free_block_queue.popleft()
+            assert curr_block.ref_cnt == 0
+
+            # If the block is cached, evict it.
+            if self.enable_caching:
+                self._maybe_evict_cached_block(curr_block)
+
+            curr_block.incr_ref()
+            ret.append(curr_block)
+            idx += 1
+
+        return ret
+
+    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
+        """
+        If a block is cached in `cached_block_hash_to_block`, we reset its hash
+        metadata and evict it from the cache.
+
+        Args:
+            block: The block to evict.
+
+        Returns:
+            True if the block is evicted, False otherwise.
+        """
+        block_hash = block.block_hash
+        if block_hash and block_hash in self.cached_block_hash_to_block:
+            block.reset_hash()
+            del self.cached_block_hash_to_block[block_hash][block.block_id]
+
+            if len(self.cached_block_hash_to_block[block_hash]) == 0:
+                del self.cached_block_hash_to_block[block_hash]
+
+            return True
+        return False
+
+    def touch(self, blocks: List[KVCacheBlock]) -> None:
+        """Touching a block increases its reference count by 1, and may remove
+        the block from the free queue. This is used when a block is hit by
+        another request with the same prefix.
+
+        Args:
+            blocks: A list of blocks to touch.
+        """
+        for block in blocks:
+            # ref_cnt=0 means this block is in the free list (i.e. eviction
+            # candidate), so remove it.
+            if block.ref_cnt == 0:
+                self.free_block_queue.remove(block)
+            block.incr_ref()
+
+    def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
+        """Free a list of blocks. The blocks should be ordered by their
+        eviction priority, where the first block will be evicted first.
+
+        Args:
+            ordered_blocks: A list of blocks to free, ordered by their
+                eviction priority.
+        """
+        for block in ordered_blocks:
+            block.decr_ref()
+            if block.ref_cnt == 0:
+                self.free_block_queue.append(block)
+
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache. This function may be used in RLHF
+        flows to invalidate prefix caching after the weights are updated,
+        or to reset the prefix caching status for benchmarking.
+
+        Returns:
+            bool: True if the prefix cache is successfully reset,
+            False otherwise.
+        """
+        num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Remove all hashes so that no new blocks will hit.
+        self.cached_block_hash_to_block = defaultdict(dict)
+
+        # Remove all hashes from all blocks.
+        for block in self.blocks:
+            block.reset_hash()
+
+        logger.info("Successfully reset prefix cache")
+        return True
+
+    def get_num_free_blocks(self) -> int:
+        """Get the number of free blocks in the pool.
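Note: cache_full_blocks threads prev_block_hash_value through the loop above, so each full block's hash covers every token before it, not just the tokens inside the block; a prefix-cache hit on block N therefore implies the whole prefix up to N matches. A minimal illustration of that chaining with a stand-in hash (toy_block_hash is not the real hash_block_tokens, and the extra keys for multi-modal inputs are omitted):

import hashlib
from typing import Optional, Tuple

def toy_block_hash(prev_hash: Optional[str],
                   block_tokens: Tuple[int, ...]) -> str:
    # Chain the parent hash into the current block's hash.
    h = hashlib.sha256()
    h.update((prev_hash or "").encode())
    h.update(",".join(map(str, block_tokens)).encode())
    return h.hexdigest()

block_size = 4
tokens = list(range(10))                 # 2 full blocks + 1 partial block
num_full_blocks = len(tokens) // block_size

prev_hash = None
block_hashes = []
for idx in range(num_full_blocks):
    block_tokens = tuple(tokens[idx * block_size:(idx + 1) * block_size])
    prev_hash = toy_block_hash(prev_hash, block_tokens)
    block_hashes.append(prev_hash)

# The same 4 tokens hashed at a different position (different parent) produce
# a different hash, which is what makes the cache prefix-aware.
assert toy_block_hash(None, (4, 5, 6, 7)) != block_hashes[1]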
+ + Returns: + The number of free blocks. + """ + return self.free_block_queue.num_free_blocks + + def get_usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). + """ + return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 017e625dcdba..fc7bfa0eff57 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -5,10 +5,8 @@ from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, - generate_block_hash_extra_keys, - hash_block_tokens, +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -49,26 +47,7 @@ def __init__( self.num_preallocate_tokens = num_preallocate_tokens self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) - # A Block pool of all kv-cache blocks. - self.block_pool: List[KVCacheBlock] = [ - KVCacheBlock(idx) for idx in range(num_gpu_blocks) - ] - # Free block queue that constructs and manipulates a doubly linked - # list of free blocks (including eviction candidates when caching is - # enabled). - self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) - - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ - int, KVCacheBlock]] = defaultdict(dict) + self.block_pool = BlockPool(num_gpu_blocks, enable_caching) # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request @@ -96,8 +75,7 @@ def usage(self) -> float: Returns: The KV cache usage (between 0.0 and 1.0). """ - return 1.0 - (self.free_block_queue.num_free_blocks / - self.num_gpu_blocks) + return self.block_pool.get_usage() def make_prefix_cache_stats(self) -> PrefixCacheStats: """Get (and reset) the prefix cache stats. @@ -139,7 +117,7 @@ def get_computed_blocks( # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. - if cached_block := self._get_cached_block(block_hash): + if cached_block := self.block_pool.get_cached_block(block_hash): computed_blocks.append(cached_block) else: break @@ -204,14 +182,14 @@ def allocate_slots( # when allocating this request. num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks if blk.ref_cnt == 0) - if (num_new_blocks > self.free_block_queue.num_free_blocks - + if (num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): # Cannot allocate new blocks return None # Touch the computed blocks to make sure they won't be evicted. 
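Note: the allocate_slots check above subtracts num_evictable_computed_blocks because computed blocks with ref_cnt == 0 are still counted as free; touching them pulls them out of the free queue, so they cannot also be handed out as new blocks. A small worked example with made-up numbers:

num_free_blocks = 4                   # blocks currently in the free queue
num_evictable_computed_blocks = 2     # prefix-cache hits with ref_cnt == 0
num_new_blocks = 3                    # blocks still needed for the new tokens

# Touching the 2 computed blocks removes them from the free queue, leaving
# only 4 - 2 = 2 truly available blocks, so this request cannot be scheduled.
can_allocate = num_new_blocks <= num_free_blocks - num_evictable_computed_blocks
assert not can_allocate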
if self.enable_caching: - self._touch(new_computed_blocks) + self.block_pool.touch(new_computed_blocks) else: assert not new_computed_blocks, ( "Computed blocks should be empty when " @@ -231,7 +209,7 @@ def allocate_slots( # preallocated blocks. num_new_blocks = min( num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, + self.block_pool.get_num_free_blocks(), # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. @@ -240,29 +218,30 @@ def allocate_slots( assert num_new_blocks > 0 # Concatenate the computed block IDs and the new block IDs. - new_blocks = self._get_new_blocks(num_new_blocks) + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) if not self.enable_caching: return new_blocks + # FIXME: `num_cached_blocks` is not correct when the prefix cache + # of a new request is hit. num_cached_blocks = self.num_cached_block[request.request_id] # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( request.spec_token_ids)) // self.block_size - new_full_blocks = req_blocks[ - num_cached_blocks:num_full_blocks_after_append] - - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=num_cached_blocks, - # The new full blocks are the full blocks that are not computed. - full_blocks=new_full_blocks, - prev_block=(req_blocks[num_cached_blocks - - 1] if num_cached_blocks > 0 else None)) + + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks, + block_hashes=self.req_to_block_hashes[request.request_id], + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks_after_append, + block_size=self.block_size, + ) + self.num_cached_block[ request.request_id] = num_full_blocks_after_append return new_blocks @@ -283,11 +262,7 @@ def free(self, request: Request) -> None: # freed first. ordered_blocks = reversed(blocks) - for block in ordered_blocks: - block.decr_ref() - if block.ref_cnt == 0: - self.free_block_queue.append(block) - + self.block_pool.free_blocks(ordered_blocks) self.num_cached_block.pop(request.request_id, None) def reset_prefix_cache(self) -> bool: @@ -299,25 +274,10 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. """ - num_used_blocks = (self.num_gpu_blocks - - self.free_block_queue.num_free_blocks) - if num_used_blocks > 0: - logger.warning( - "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks) - return False - - # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) - - # Remove all hashes from all blocks. - for block in self.block_pool: - block.reset_hash() - - self.prefix_cache_stats.reset = True - - logger.info("Successfully reset prefix cache") - return True + if self.block_pool.reset_prefix_cache(): + self.prefix_cache_stats.reset = True + return True + return False def get_num_common_prefix_blocks( self, @@ -367,177 +327,6 @@ def get_num_common_prefix_blocks( break return num_common_blocks - def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: - """Get new blocks from the free block pool. - - Note that we do not check block cache in this function. - - Args: - num_blocks: The number of blocks to allocate. 
- - Returns: - A list of new block. - """ - if num_blocks > self.free_block_queue.num_free_blocks: - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") - - ret: List[KVCacheBlock] = [] - idx = 0 - while idx < num_blocks: - # First allocate blocks. - curr_block = self.free_block_queue.popleft() - assert curr_block.ref_cnt == 0 - - # If the block is cached, evict it. - if self.enable_caching: - self._maybe_evict_cached_block(curr_block) - - curr_block.incr_ref() - ret.append(curr_block) - idx += 1 - - return ret - - def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: - """ - If a block is cached in `cached_block_hash_to_block`, we reset its hash - metadata and evict it from the cache. - - Args: - block: The block to evict. - - Returns: - True if the block is evicted, False otherwise. - """ - block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] - - return True - return False - - def _get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: - """Get a cached block by the block hash, or None if cache miss. - If there are duplicated blocks, we return the first block in the cache. - - Args: - block_hash: The hash value of the block. - - Returns: - The cached block if it exists, or None. - """ - if block_hash in self.cached_block_hash_to_block: - first_block_id = list( - self.cached_block_hash_to_block[block_hash].keys())[0] - return self.cached_block_hash_to_block[block_hash][first_block_id] - return None - - def _touch(self, blocks: List[KVCacheBlock]) -> None: - """Touch a block increases its reference count by 1, and may remove - the block from the free queue. This is used when a block is hit by - another request with the same prefix. - - Args: - blocks: A list of blocks to touch. - """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0: - self.free_block_queue.remove(block) - block.incr_ref() - - def _cache_full_blocks( - self, - request: Request, - blk_start_idx: int, - full_blocks: List[KVCacheBlock], - prev_block: Optional[KVCacheBlock], - ) -> None: - """Cache a list of full blocks for prefix caching. - - This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `blk_start_idx` to the end - of the request's full blocks, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. - - Args: - request: The request to cache the blocks. - blk_start_idx: The index of the first block in the request's blocks - to cache. - full_blocks: The list of blocks to update hash metadata. - prev_block: The previous block in the chain. - """ - block_hashes = self.req_to_block_hashes[request.request_id] - num_cached_block_hashes = len(block_hashes) - - # Update the new blocks with the block hashes through the chain. - prev_block_hash_value = None - if prev_block is not None: - # Previous block must have a block hash because it must be - # a full, cached block. - assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.hash_value - - # Find the first uncached block. This case should only happen when - # speculative decoding is used. 
- offset = 0 - for blk in full_blocks: - if blk.block_hash is None: - break - else: - prev_block_hash_value = blk.block_hash.hash_value - offset += 1 - else: - # All blocks are cached. - return - - for i, blk in enumerate(full_blocks[offset:]): - blk_idx = blk_start_idx + offset + i - assert blk.block_hash is None - - if blk_idx < num_cached_block_hashes: - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption). In this case we simply - # reuse the block hash. - block_hash = block_hashes[blk_idx] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - start_token_idx = blk_idx * self.block_size - end_token_idx = (blk_idx + 1) * self.block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, - block_tokens, extra_keys) - block_hashes.append(block_hash) - - # Update and added the full block to the cache. - blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk - prev_block_hash_value = block_hash.hash_value - def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request.
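Note: with this split, KVCacheManager keeps the per-request bookkeeping (block tables, block hashes, preallocation) while BlockPool owns the physical blocks. A rough sketch of driving the extracted pool directly, assuming this PR is applied (cache_full_blocks is left out because it needs a full Request object):

from vllm.v1.core.block_pool import BlockPool

pool = BlockPool(num_gpu_blocks=8, enable_caching=True)

# Allocate blocks for a request; each block comes back with ref_cnt == 1.
blocks = pool.get_new_blocks(3)
assert pool.get_num_free_blocks() == 5

# A second request hitting the same prefix bumps the reference counts.
pool.touch(blocks)

# Free in reverse order so the last block becomes the first eviction candidate.
pool.free_blocks(reversed(blocks))    # ref_cnt 2 -> 1, blocks stay in use
pool.free_blocks(reversed(blocks))    # ref_cnt 1 -> 0, back on the free queue

assert pool.get_num_free_blocks() == 8
assert pool.get_usage() == 0.0
assert pool.reset_prefix_cache()      # succeeds only when no block is in use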