diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index 74f3f7852c9a..25bf5bc91e63 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -39,7 +39,7 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool):
         # enabled).
         self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
 
-        # {block_hash: {block ID: block}}. A cached block is
+        # {block_hash: {group ID: {block ID: block}}}. A cached block is
         # a full block with a block hash that can be used for prefix caching.
         # The cached block may be used by running requests or in the
         # free_block_queue that could potentially be evicted.
@@ -48,16 +48,19 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool):
         # if there is already an identical block in the cache. This is because
         # we want to make sure the allocated block IDs won't change so that
         # block tables are append-only.
-        self.cached_block_hash_to_block: dict[BlockHashType, dict[
-            int, KVCacheBlock]] = defaultdict(dict)
+        self.cached_block_hash_to_block: dict[BlockHashType, dict[int, dict[
+            int, KVCacheBlock]]] = defaultdict(dict)
 
         # To represent a placeholder block with block_id=0.
         # The ref_cnt of null_block is not maintained, needs special care to
         # avoid freeing it.
         self.null_block = self.free_block_queue.popleft()
+        self.null_block.is_null = True
 
-    def get_cached_block(self,
-                         block_hash: BlockHashType) -> Optional[KVCacheBlock]:
+    def get_cached_block(
+        self,
+        block_hash: BlockHashType,
+    ) -> Optional[dict[int, KVCacheBlock]]:
         """Get a cached block by the block hash, or None if cache miss.
         If there are duplicated blocks, we return the first block in the cache.
 
@@ -70,8 +73,10 @@ def get_cached_block(self,
         cached_blocks = self.cached_block_hash_to_block.get(block_hash)
         if not cached_blocks:
             return None
-        first_block_id = next(iter(cached_blocks))
-        return cached_blocks[first_block_id]
+        return {
+            group_id: next(iter(blocks.values()))
+            for group_id, blocks in cached_blocks.items() if blocks
+        }
 
     def cache_full_blocks(
         self,
diff --git a/vllm/v1/core/hybrid_allocator.py b/vllm/v1/core/hybrid_allocator.py
new file mode 100644
index 000000000000..f6102a17669d
--- /dev/null
+++ b/vllm/v1/core/hybrid_allocator.py
@@ -0,0 +1,364 @@
+# SPDX-License-Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Any, Optional
+
+from vllm.v1.core.block_pool import BlockPool
+from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
+from vllm.v1.core.specialized_manager import (FullAttentionAllocator,
+                                              SlidingWindowAllocator,
+                                              SpecializedAllocator,
+                                              get_specialized_allocator)
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.request import Request
+
+
+class HybridMemoryAllocator(ABC):
+
+    def __init__(self, block_pool: BlockPool):
+        self.block_pool = block_pool
+        # req_id -> group_id -> block_hashes
+        # This is to avoid recomputing the block hashes for each call of
+        # `get_block_hashes`.
+        # NOTE: These entries must be freed when requests are finished
+        # to prevent memory leaks.
+        self.req_to_block_hashes: dict[str, list[list[BlockHashType]]] = {}
+
+    def get_block_hashes(
+        self,
+        request: Request,
+        hash_fn: Any,
+    ) -> list[list[BlockHashType]]:
+        # The block hashes for the request may already be computed
+        # if the scheduler has tried to schedule the request before.
+ block_hashes = self.req_to_block_hashes.get(request.request_id) + if block_hashes is None: + block_hashes = self._get_block_hashes(request, hash_fn) + self.req_to_block_hashes[request.request_id] = block_hashes + return block_hashes + + @abstractmethod + def _get_block_hashes( + self, + request: Request, + hash_fn: Any, + ) -> list[list[BlockHashType]]: + raise NotImplementedError + + @abstractmethod + def find_longest_cache_hit( + self, + block_hashes: list[list[BlockHashType]], + num_tokens: int, + ) -> tuple[list[list[KVCacheBlock]], int]: + raise NotImplementedError + + @abstractmethod + def remove_skipped_blocks( + self, + blocks: list[list[KVCacheBlock]], + num_computed_tokens: int, + ) -> Iterable[KVCacheBlock]: + raise NotImplementedError + + def allocate_blocks( + self, + total_num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]], + allocated_blocks: list[list[KVCacheBlock]], + ) -> Optional[list[list[KVCacheBlock]]]: + num_new_blocks = self._get_num_new_blocks( + total_num_tokens, + num_computed_tokens, + new_computed_blocks, + allocated_blocks, + ) + total_num_new_blocks = sum(num_new_blocks) + if total_num_new_blocks <= 0: + # No new block is needed. + return [] + flattened_new_computed_blocks = sum(new_computed_blocks, []) + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = sum( + 1 for blk in flattened_new_computed_blocks if blk.ref_cnt == 0) + if (total_num_new_blocks > self.block_pool.get_num_free_blocks() - + num_evictable_computed_blocks): + # Cannot allocate new blocks + return None + + # Touch the computed blocks to make sure they won't be evicted. + if flattened_new_computed_blocks: + self.block_pool.touch(flattened_new_computed_blocks) + + total_new_blocks = self.block_pool.get_new_blocks(total_num_new_blocks) + new_blocks: list[list[KVCacheBlock]] = [] + start = 0 + for n in num_new_blocks: + end = start + n + new_blocks.append(total_new_blocks[start:end]) + start = end + return new_blocks + + @abstractmethod + def _get_num_new_blocks( + self, + total_num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]], + allocated_blocks: list[list[KVCacheBlock]], + ) -> list[int]: + raise NotImplementedError + + @abstractmethod + def cache_blocks( + self, + request: Request, + blocks: list[list[KVCacheBlock]], + num_computed_tokens: int, + num_new_tokens: int, + num_cached_blocks: list[int], + hash_fn: Any, + ) -> list[int]: + raise NotImplementedError + + @abstractmethod + def sort_by_eviction_order( + self, + blocks: list[list[KVCacheBlock]], + ) -> Iterable[KVCacheBlock]: + raise NotImplementedError + + +class SingleMemoryAllocator(HybridMemoryAllocator): + """Memory allocator for a single attention type. + + For example, models with full attention only (e.g., Llama 3, DeepSeek) and + models with sliding window attention only (e.g., an old version of Mistral) + use this allocator. 
+ """ + + def __init__( + self, + block_pool: BlockPool, + allocator: SpecializedAllocator, + ): + super().__init__(block_pool) + self.allocator = allocator + self.block_size = allocator.block_size + self.group_ids = (0, ) + + def _get_block_hashes( + self, + request: Request, + hash_fn: Any, + ) -> list[list[BlockHashType]]: + return [self.allocator.get_block_hashes(request, hash_fn)] + + def find_longest_cache_hit( + self, + block_hashes: list[list[BlockHashType]], + num_tokens: int, + ) -> tuple[list[list[KVCacheBlock]], int]: + block_hashes = block_hashes[0] + if len(block_hashes) * self.block_size == num_tokens: + block_hashes = block_hashes[:-1] + blocks, num_computed_tokens = self.allocator.find_longest_cache_hit( + block_hashes, self.group_ids) + return [blocks[0]], num_computed_tokens + + def remove_skipped_blocks( + self, + blocks: dict[int, list[KVCacheBlock]], + num_computed_tokens: int, + ) -> Iterable[KVCacheBlock]: + return self.allocator.remove_skipped_blocks(blocks, self.group_ids, + num_computed_tokens) + + def _get_num_new_blocks( + self, + total_num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]], + allocated_blocks: list[list[KVCacheBlock]], + ) -> list[int]: + num_new_blocks = self.allocator.get_num_new_blocks( + total_num_tokens, + num_computed_tokens, + new_computed_blocks, + allocated_blocks, + self.group_ids, + ) + return [num_new_blocks[0]] + + def cache_blocks( + self, + request: Request, + blocks: list[list[KVCacheBlock]], + num_computed_tokens: int, + num_new_tokens: int, + num_cached_blocks: list[int], + hash_fn: Any, + ) -> list[int]: + return self.allocator.cache_blocks( + request, + blocks[0], + num_computed_tokens, + num_new_tokens, + num_cached_blocks[0], + hash_fn, + ) + + def sort_by_eviction_order( + self, + blocks: list[list[KVCacheBlock]], + ) -> Iterable[KVCacheBlock]: + return self.allocator.sort_by_eviction_order(blocks[0]) + + +class FullAndSwaMemoryAllocator(HybridMemoryAllocator): + """Memory allocator for full and sliding window attention. + + For example, models like Gemma 2 (1:1 full/swa) and Gemma 3 (1:5 full/swa) + use this allocator. + """ + + def __init__( + self, + block_pool: BlockPool, + full_attn_allocator: FullAttentionAllocator, + full_attn_group_ids: tuple[int, ...], + swa_allocator: SlidingWindowAllocator, + swa_group_ids: tuple[int, ...], + ): + super().__init__(block_pool) + self.full_attn_allocator = full_attn_allocator + self.full_attn_group_ids = full_attn_group_ids + self.swa_allocator = swa_allocator + self.swa_group_ids = swa_group_ids + + self.all_group_ids = sorted(full_attn_group_ids + swa_group_ids) + self.num_groups = len(self.all_group_ids) + self.block_size = full_attn_allocator.block_size + if self.block_size != swa_allocator.block_size: + raise ValueError( + f"The block size of full attention ({self.block_size}) and " + f"sliding window attention ({swa_allocator.block_size}) must " + "be the same.") + + def _get_block_hashes( + self, + request: Request, + hash_fn: Any, + ) -> list[list[BlockHashType]]: + # The full attention and sliding window attention use the same block + # size. + block_hashes = self.full_attn_allocator.get_block_hashes( + request, hash_fn) + # TODO(woosuk): Optimize this. 
+ return [block_hashes] * self.num_groups + + def find_longest_cache_hit( + self, + block_hashes: list[list[BlockHashType]], + num_tokens: int, + ) -> tuple[list[list[KVCacheBlock]], int]: + # Because the full attention and sliding window attention use the same + # block size, we can just use the block hashes for any group. + block_hashes = block_hashes[0] + if len(block_hashes) * self.block_size == num_tokens: + block_hashes = block_hashes[:-1] + + # First, find the longest cache hit for full attention. + full_attn_blocks, num_full_attn_tokens = ( + self.full_attn_allocator.find_longest_cache_hit( + block_hashes, self.full_attn_group_ids)) + num_full_attn_blocks = num_full_attn_tokens // self.block_size + if num_full_attn_blocks == 0: + # No cache hit. + return [[]] * self.num_groups, 0 + + # Next, find the cache hit for sliding window attention WITHIN the + # cache hit of full attention. + block_hashes = block_hashes[:num_full_attn_blocks] + swa_attn_blocks, num_swa_attn_tokens = ( + self.swa_allocator.find_longest_cache_hit(block_hashes, + self.swa_group_ids)) + num_swa_attn_blocks = num_swa_attn_tokens // self.block_size + if num_swa_attn_blocks == 0: + # No cache hit. + return [[]] * self.num_groups, 0 + + # Truncate the full attention cache hit to the length of the + # sliding window cache hit. + num_blocks = num_swa_attn_blocks + num_computed_tokens = num_swa_attn_tokens + + combined_blocks: list[list[KVCacheBlock]] = [] + for group_id in self.all_group_ids: + if group_id in self.full_attn_group_ids: + combined_blocks.append(full_attn_blocks[group_id][:num_blocks]) + else: + # We don't need `[:num_blocks]` here. + combined_blocks.append(swa_attn_blocks[group_id]) + return combined_blocks, num_computed_tokens + + def remove_skipped_blocks( + self, + blocks: list[list[KVCacheBlock]], + num_computed_tokens: int, + ) -> list[KVCacheBlock]: + return self.swa_allocator.remove_skipped_blocks( + blocks, self.swa_group_ids, num_computed_tokens) + + def _get_num_new_blocks( + self, + total_num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]], + allocated_blocks: list[list[KVCacheBlock]], + ) -> list[int]: + # OPTIMIZATION(woosuk): + group_id = self.full_attn_group_ids[0] + num_new_blocks = self.full_attn_allocator.get_num_new_blocks( + total_num_tokens, + num_computed_tokens, + new_computed_blocks[group_id], + allocated_blocks[group_id], + group_ids=(group_id, ), + ) + return [num_new_blocks] * self.num_groups + + def sort_by_eviction_order( + self, + blocks: list[list[KVCacheBlock]], + ) -> Iterable[KVCacheBlock]: + num_blocks = len(blocks[0]) + group_ids = self.all_group_ids + return [ + block for i in reversed(range(num_blocks)) + for group_id in group_ids + if not (block := blocks[group_id][i]).is_null + ] + + +def get_hybrid_allocator( + kv_cache_config: KVCacheConfig, + block_pool: BlockPool, +) -> HybridMemoryAllocator: + num_groups = len(kv_cache_config.kv_cache_groups) + if num_groups == 1: + kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec + allocator = get_specialized_allocator( + block_pool=block_pool, + kv_cache_spec=kv_cache_spec, + ) + return SingleMemoryAllocator( + block_pool=block_pool, + allocator=allocator, + ) + else: + raise NotImplementedError diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 33761cf7f9c0..b4a13d5913ad 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,15 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 - -from collections 
import defaultdict -from collections.abc import Iterable +import itertools from typing import Optional from vllm.logger import init_logger -from vllm.utils import cdiv, sha256 +from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, - hash_request_tokens) -from vllm.v1.core.specialized_manager import get_specialized_manager +from vllm.v1.core.hybrid_allocator import get_hybrid_allocator +from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -28,58 +25,34 @@ def __init__( num_preallocate_tokens: int = 64, log_stats: bool = False, ) -> None: - assert len(kv_cache_config.kv_cache_groups) == 1, ( - "KVCacheManager does not support hybrid models with more than 1 " - "kv cache group") - kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec - self.block_size = kv_cache_spec.block_size self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash # FIXME: make prefix cache stats conditional on log_stats self.log_stats = log_stats - # NOTE(woosuk): To avoid frequent block allocation, we preallocate some - # blocks for each request. For example, when a request reaches the end - # of its block table, we preallocate N blocks in advance. This way, we - # reduce the overhead of updating free_block_ids and ref_cnts for each - # request every step (at the cost of some memory waste). - # NOTE(woosuk): This is different from the "lookahead" slots since this - # does not guarantee that the request always has N empty blocks. After - # the request gets N empty blocks, it starts to use the blocks without - # further allocation. When it uses up all the N empty blocks, it gets - # N new empty blocks. - self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv(num_preallocate_tokens, - self.block_size) + self.prefix_cache_stats = PrefixCacheStats() + self.num_preallocate_tokens = num_preallocate_tokens self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) - self.specialized_manager = get_specialized_manager( - kv_cache_spec=kv_cache_spec, - block_pool=self.block_pool, - ) - # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) - - # Mapping from request ID to kv block hashes. - # This is to avoid recomputing the block hashes for each call of - # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHashType]] = defaultdict(list) + self.req_to_blocks: dict[str, list[list[KVCacheBlock]]] = {} # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. 
-        self.num_cached_block: dict[str, int] = {}
-        self.prefix_cache_stats = PrefixCacheStats()
+        self.num_cached_blocks: dict[str, list[int]] = {}
+
+        self.allocator = get_hybrid_allocator(
+            kv_cache_config=kv_cache_config,
+            block_pool=self.block_pool,
+        )
+        self.num_groups = len(kv_cache_config.kv_cache_groups)
 
     @property
     def usage(self) -> float:
@@ -101,7 +74,9 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats:
         return stats
 
     def get_computed_blocks(
-            self, request: Request) -> tuple[list[KVCacheBlock], int]:
+        self,
+        request: Request,
+    ) -> tuple[list[list[KVCacheBlock]], int]:
         """Get the computed (cached) blocks for the request.
         Note that the computed blocks must be full.
 
@@ -117,56 +92,31 @@ def get_computed_blocks(
             # Prefix caching is disabled.
             return [], 0
 
-        # The block hashes for the request may already be computed
-        # if the scheduler has tried to schedule the request before.
-        block_hashes = self.req_to_block_hashes[request.request_id]
-        if not block_hashes:
-            block_hashes = hash_request_tokens(self.caching_hash_fn,
-                                               self.block_size, request)
-            self.req_to_block_hashes[request.request_id] = block_hashes
-
+        block_hashes = self.allocator.get_block_hashes(request,
+                                                       self.caching_hash_fn)
         self.prefix_cache_stats.requests += 1
-        # When the request requires prompt logprobs, we skip prefix caching.
+        # If the request requires prompt logprobs, we skip prefix caching.
         if request.sampling_params.prompt_logprobs is not None:
             return [], 0
 
-        if len(block_hashes) * self.block_size == request.num_tokens:
-            # When prompt length is divisible by the block size and all
-            # blocks are cached, we need to recompute the last token. This
-            # have to be achieved by re-computing an entire block because
-            # allocate_slots() assumes num_computed_tokens is always a
-            # multiple of the block size. To achieve this, remove the last
-            # block hash from the block_hashes for find_longest_cache_hit
-            # This limitation can potentially be removed in the future to
-            # slightly improve the performance.
-            last_block_hash = block_hashes.pop()
-        else:
-            last_block_hash = None
-
-        computed_blocks = (
-            self.specialized_manager.find_longest_cache_hit(block_hashes))
-        self.prefix_cache_stats.queries += len(block_hashes)
-        self.prefix_cache_stats.hits += len(computed_blocks)
-
-        if last_block_hash is not None:
-            # Add back the last block hash if it was removed.
-            # NOTE: Because block_hashes is cached in req_to_block_hashes,
-            # we shouldn't modify it directly.
-            block_hashes.append(last_block_hash)
-
-        # NOTE(woosuk): Since incomplete blocks are not eligible for
-        # sharing, `num_computed_tokens` is always a multiple of
-        # `block_size`.
-        num_computed_tokens = len(computed_blocks) * self.block_size
+        num_tokens = request.num_tokens
+        computed_blocks, num_computed_tokens = (
+            self.allocator.find_longest_cache_hit(block_hashes, num_tokens))
+
+        self.prefix_cache_stats.queries += num_tokens
+        self.prefix_cache_stats.hits += num_computed_tokens
         return computed_blocks, num_computed_tokens
 
     def allocate_slots(
         self,
         request: Request,
-        num_tokens: int,
-        new_computed_blocks: Optional[list[KVCacheBlock]] = None,
+        num_input_tokens: int,
+        num_draft_tokens: int = 0,
+        new_computed_tokens: int = 0,
+        new_computed_blocks: Optional[list[list[KVCacheBlock]]] = None,
        num_lookahead_tokens: int = 0,
-    ) -> Optional[list[KVCacheBlock]]:
+    ) -> Optional[list[list[KVCacheBlock]]]:
         """Add slots for a request with new tokens to append.
Args: @@ -194,12 +144,16 @@ def allocate_slots( Returns: A list of new allocated blocks. """ - if num_tokens == 0: - raise ValueError("num_tokens must be greater than 0") + assert num_input_tokens + num_draft_tokens > 0 - new_computed_blocks = new_computed_blocks or [] + if new_computed_blocks is None: + new_computed_blocks = [[] for _ in range(self.num_groups)] + assert new_computed_tokens == 0 - req_blocks = self.req_to_blocks[request.request_id] + req_blocks = self.req_to_blocks.get(request.request_id) + if req_blocks is None: + req_blocks = [[] for _ in range(self.num_groups)] + self.req_to_blocks[request.request_id] = req_blocks # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -207,92 +161,54 @@ def allocate_slots( # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. - removed_blocks = self.specialized_manager.remove_skipped_blocks( + removed_blocks = self.allocator.remove_skipped_blocks( req_blocks, request.num_computed_tokens) self.block_pool.free_blocks(removed_blocks) - # The number of computed tokens is the number of computed tokens plus - # the new prefix caching hits - num_computed_tokens = (request.num_computed_tokens + - len(new_computed_blocks) * self.block_size) - num_required_blocks = cdiv( - num_computed_tokens + num_tokens + num_lookahead_tokens, - self.block_size) - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_blocks)) - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks - if blk.ref_cnt == 0) - if (num_new_blocks > self.block_pool.get_num_free_blocks() - - num_evictable_computed_blocks): - # Cannot allocate new blocks + num_computed_tokens = request.num_computed_tokens + new_computed_tokens + num_preallocate_tokens = max(self.num_preallocate_tokens, + num_lookahead_tokens) + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + total_num_tokens = min( + num_computed_tokens + num_input_tokens + num_draft_tokens + + num_preallocate_tokens, self.max_model_len) + + new_blocks = self.allocator.allocate_blocks( + total_num_tokens, + num_computed_tokens, + req_blocks, + new_computed_blocks, + ) + if new_blocks is None: + # Cannot allocate new blocks. return None - # Touch the computed blocks to make sure they won't be evicted. - if self.enable_caching: - self.block_pool.touch(new_computed_blocks) - else: - assert not new_computed_blocks, ( - "Computed blocks should be empty when " - "prefix caching is disabled") - - # Append the new computed blocks to the request blocks until now to - # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_blocks) - - # Start to handle new blocks - - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_preallocate_blocks = max( - 0, self.num_preallocate_blocks - - num_lookahead_tokens // self.block_size) - num_new_blocks = min( - num_new_blocks + num_preallocate_blocks, - self.block_pool.get_num_free_blocks(), - # Should not exceed the maximum number of blocks per request. 
- # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) - + # Add the new computed blocks and new blocks to the request. + for group_id in range(self.num_groups): + req_blocks[group_id].extend(new_computed_blocks[group_id]) + req_blocks[group_id].extend(new_blocks[group_id]) if not self.enable_caching: return new_blocks - # Use `new_computed_blocks` for a new request, and `num_cached_block` + # Use `new_computed_blocks` for a new request, and `num_cached_blocks` # for a running request. - num_cached_blocks = self.num_cached_block.get(request.request_id, - len(new_computed_blocks)) - # Speculated tokens might be rejected in the future, so we does - # not cache any speculated tokens. We only cache blocks with - # generated (accepted) tokens. - num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( - request.spec_token_ids)) // self.block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks, - block_hashes=self.req_to_block_hashes[request.request_id], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks_after_append, - block_size=self.block_size, - hash_fn=self.caching_hash_fn, + num_cached_blocks = self.num_cached_blocks.get( + request.request_id, + [len(blocks) for blocks in new_computed_blocks], ) - - self.num_cached_block[ - request.request_id] = num_full_blocks_after_append + # NOTE(woosuk): Since draft tokens can be rejected, we should not cache + # any blocks including draft tokens. + num_cached_blocks = self.allocator.cache_blocks( + request, + req_blocks, + num_computed_tokens, + num_input_tokens, # No draft tokens or lookahead tokens + num_cached_blocks, + self.caching_hash_fn, + ) + self.num_cached_blocks[request.request_id] = num_cached_blocks return new_blocks def free(self, request: Request) -> None: @@ -303,16 +219,18 @@ def free(self, request: Request) -> None: Args: request: The request to free the blocks. """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. - ordered_blocks = reversed(blocks) + # Default to None in case a request is freed (aborted) before alloc. + self.num_cached_blocks.pop(request.request_id, None) + blocks = self.req_to_blocks.pop(request.request_id, None) + if blocks is None: + return + if self.enable_caching: + ordered_blocks = self.allocator.sort_by_eviction_order(blocks) + else: + # When caching is disabled, free the blocks in any order. + ordered_blocks = itertools.chain.from_iterable(blocks) self.block_pool.free_blocks(ordered_blocks) - self.num_cached_block.pop(request.request_id, None) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -382,4 +300,4 @@ def free_block_hashes(self, request: Request) -> None: NOTE: Unlike `free`, this method should be called only when the request is finished, not when it is preempted. 
""" - self.req_to_block_hashes.pop(request.request_id, None) + self.allocator.req_to_block_hashes.pop(request.request_id, None) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bd0e01d045d1..dad67c7bee4d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -124,6 +124,9 @@ class KVCacheBlock: prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None + # Whether the block is a null block. + is_null: bool = False + def incr_ref(self): self.ref_cnt += 1 @@ -152,12 +155,19 @@ def __repr__(self) -> str: next_block_id = self.next_free_block.block_id \ if self.next_free_block else None return (f"KVCacheBlock(block_id={self.block_id}, " + f"is_null={self.is_null}, " f"ref_cnt={self.ref_cnt}, " f"_block_hash={self._block_hash}, " f"prev_free_block={prev_block_id}, " f"next_free_block={next_block_id})") +@dataclass +class BlocksPerGroup: + + blocks: list[list[KVCacheBlock]] + + class FreeKVCacheBlockQueue: """This class organizes a list of KVCacheBlock objects to a doubly linked list of free blocks. We implement this class instead of using Python diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index dc0d2d59fea7..86a15a32a6e2 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -25,7 +25,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -33,7 +33,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[int], + block_ids: list[list[int]], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -58,7 +58,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[int] + new_block_ids: list[list[int]] num_computed_tokens: int @classmethod @@ -67,7 +67,7 @@ def from_request( request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[int], + new_block_ids: list[list[int]], ) -> CachedRequestData: return cls( req_id=request.request_id, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a81574875a5c..16a9189aaa05 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2,6 +2,7 @@ from __future__ import annotations +import itertools import time from collections import deque from collections.abc import Iterable @@ -144,7 +145,7 @@ def schedule(self) -> SchedulerOutput: # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[int]] = {} + req_to_new_block_ids: dict[str, list[list[int]]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. 
@@ -165,7 +166,8 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue - num_new_tokens = (request.num_tokens_with_spec - + num_draft_tokens = len(request.draft_token_ids) + num_new_tokens = (request.num_tokens + num_draft_tokens - request.num_computed_tokens) if (0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens): @@ -196,7 +198,8 @@ def schedule(self) -> SchedulerOutput: while True: new_blocks = self.kv_cache_manager.allocate_slots( request, - num_new_tokens, + num_new_tokens - num_draft_tokens, + num_draft_tokens=num_draft_tokens, num_lookahead_tokens=self.num_lookahead_tokens) if new_blocks is None: # The request cannot be scheduled. @@ -233,7 +236,7 @@ def schedule(self) -> SchedulerOutput: # cycle to fill in the bitmask, which could be a big no-op. structured_output_request_ids[request.request_id] = req_index req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks + [b.block_id for b in blocks] for blocks in new_blocks ] num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens @@ -330,7 +333,11 @@ def schedule(self) -> SchedulerOutput: new_encoder_budget = encoder_budget new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, + num_new_tokens, + new_computed_tokens=num_computed_tokens, + new_computed_blocks=computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens) if new_blocks is None: # The request cannot be scheduled. break @@ -355,9 +362,9 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in computed_blocks + new_blocks - ] + req_to_new_block_ids[request.request_id] = [[ + b.block_id for b in itertools.chain(b1, b2) + ] for b1, b2 in zip(computed_blocks, new_blocks)] num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -463,7 +470,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[int], + new_block_ids: list[list[int]], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index 7a8a98361c7e..1b9e6295f982 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -1,16 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence +from typing import Any from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + hash_request_tokens) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) +from vllm.v1.request import Request -class SpecializedManager(ABC): +class SpecializedAllocator(ABC): """ - An abstract base class for specialized managers that handle the kv + An abstract base class for specialized allocators that handle the kv cache management logic of different attention layers. """ @@ -20,19 +24,28 @@ def __init__( block_pool: BlockPool, ) -> None: """ - Initializes the SpecializedManager. + Initializes the SpecializedAllocator. 
Args: - kv_cache_spec: The kv_cache_spec for this manager. + kv_cache_spec: The kv_cache_spec for this allocator. block_pool: The block pool. """ - - self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool + self.block_size = kv_cache_spec.block_size + + def get_block_hashes( + self, + request: Request, + hash_fn: Any, + ) -> list[BlockHashType]: + return hash_request_tokens(hash_fn, self.block_size, request) @abstractmethod def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + self, + block_hashes: list[BlockHashType], + group_ids: Sequence[int], + ) -> tuple[dict[int, list[KVCacheBlock]], int]: """ Get the longest cache hit prefix of the blocks. If no cache hit is found, return an empty list. @@ -41,16 +54,19 @@ def find_longest_cache_hit( block_hashes: The block hashes of the request. Returns: A list of cached blocks with skipped blocks replaced by null block. - For example, sliding window manager should return a list like + For example, sliding window allocator should return a list like [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and sliding window 8. """ - raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks( + self, + blocks: list[list[KVCacheBlock]], + group_ids: Sequence[int], + num_computed_tokens: int, + ) -> Iterable[KVCacheBlock]: """ Remove the blocks that are no longer needed from `blocks`. The removed blocks should be replaced by null_block. Return the removed blocks in @@ -65,29 +81,107 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], """ raise NotImplementedError + @abstractmethod + def get_num_new_blocks( + self, + total_num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]], + allocated_blocks: list[list[KVCacheBlock]], + group_ids: Sequence[int], + ) -> dict[int, int]: + raise NotImplementedError + + def cache_blocks( + self, + request: Request, + blocks: list[KVCacheBlock], + block_hashes: list[BlockHashType], + num_computed_tokens: int, + num_new_tokens: int, + num_cached_blocks: int, + hash_fn: Any, + ) -> int: + num_full_blocks = (num_computed_tokens + + num_new_tokens) // self.block_size + self.block_pool.cache_full_blocks( + request=request, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks, + block_size=self.block_size, + hash_fn=hash_fn, + ) + return num_full_blocks + + @abstractmethod + def sort_by_eviction_order( + self, + blocks: list[KVCacheBlock], + ) -> Iterable[KVCacheBlock]: + raise NotImplementedError + -class FullAttentionManager(SpecializedManager): +class FullAttentionAllocator(SpecializedAllocator): def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: - computed_blocks: list[KVCacheBlock] = [] + self, + block_hashes: list[BlockHashType], + group_ids: Sequence[int], + ) -> tuple[dict[int, list[KVCacheBlock]], int]: + computed_blocks: dict[int, list[KVCacheBlock]] = { + group_id: [] + for group_id in group_ids + } for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. 
- if cached_block := self.block_pool.get_cached_block(block_hash): - computed_blocks.append(cached_block) + cached_blocks = self.block_pool.get_cached_block(block_hash) + if (cached_blocks and all(group_id in cached_blocks + for group_id in group_ids)): + for group_id in group_ids: + computed_blocks[group_id].append(cached_blocks[group_id]) else: break - return computed_blocks + num_computed_blocks = len(computed_blocks[group_ids[0]]) + num_computed_tokens = num_computed_blocks * self.block_size + return computed_blocks, num_computed_tokens - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: - # No need to remove blocks for full attention. + def remove_skipped_blocks( + self, + blocks: list[list[KVCacheBlock]], + group_ids: Sequence[int], + num_computed_tokens: int, + ) -> Iterable[KVCacheBlock]: + # Full attention skips no blocks. return [] + def get_num_new_blocks( + self, + total_num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]], + allocated_blocks: list[list[KVCacheBlock]], + group_ids: Sequence[int], + ) -> dict[int, int]: + num_required_blocks = cdiv(total_num_tokens, self.block_size) + num_new_blocks: dict[int, int] = {} + for group_id in group_ids: + num_new_blocks[group_id] = (num_required_blocks - + len(allocated_blocks[group_id]) - + len(new_computed_blocks[group_id])) + return num_new_blocks + + def sort_by_eviction_order( + self, + blocks: list[KVCacheBlock], + ) -> Iterable[KVCacheBlock]: + return reversed(blocks) -class SlidingWindowManager(SpecializedManager): + +class SlidingWindowAllocator(SpecializedAllocator): def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool): @@ -100,37 +194,54 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, self._null_block = block_pool.null_block def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + self, + block_hashes: list[BlockHashType], + group_ids: Sequence[int], + ) -> tuple[dict[int, list[KVCacheBlock]], int]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(len(block_hashes)) to # O(len(block_hashes) / sliding_window_contiguous_blocks + # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. - computed_blocks = [self._null_block] * len(block_hashes) + computed_blocks = { + group_id: [self._null_block] * len(block_hashes) + for group_id in group_ids + } num_contiguous_blocks = 0 # Search from right to left and early stop when a match is found. for i in range(len(block_hashes) - 1, -1, -1): - if cached_block := self.block_pool.get_cached_block( - block_hashes[i]): - computed_blocks[i] = cached_block - num_contiguous_blocks += 1 - if (num_contiguous_blocks - >= self.sliding_window_contiguous_blocks): - # Trim the trailing blocks. - # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] - # when sliding_window_contiguous_blocks=2. - del computed_blocks[i + num_contiguous_blocks:] - return computed_blocks - else: + block_hash = block_hashes[i] + cached_blocks = self.block_pool.get_cached_block(block_hash) + if (cached_blocks is None or any(group_id not in cached_blocks + for group_id in group_ids)): num_contiguous_blocks = 0 + continue + + for group_id in group_ids: + computed_blocks[group_id][i] = cached_blocks[group_id] + num_contiguous_blocks += 1 + if (num_contiguous_blocks + >= self.sliding_window_contiguous_blocks): + # Trim the trailing blocks. 
+ # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] + # when sliding_window_contiguous_blocks=2. + for group_id in group_ids: + del computed_blocks[group_id][i + num_contiguous_blocks:] + return computed_blocks + # The first `num_contiguous_blocks` is a cache hit even if # `num_contiguous_blocks < sliding_window_contiguous_blocks`. - del computed_blocks[num_contiguous_blocks:] + for group_id in group_ids: + del computed_blocks[group_id][num_contiguous_blocks:] return computed_blocks - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks( + self, + blocks: dict[int, list[KVCacheBlock]], + group_ids: Sequence[int], + num_computed_tokens: int, + ) -> Iterable[KVCacheBlock]: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 @@ -138,24 +249,54 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], removed_blocks: list[KVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): - if blocks[i] == self._null_block: - # If the block is already a null block, the blocks before it - # should also have been set to null blocks by the previous calls - # to this function. + met_null_block = False + for group_id in group_ids: + block = blocks[group_id][i] + if block.is_null: + # If the block is already a null block, the blocks before it + # should also have been set to null blocks by the previous + # calls to this function. + met_null_block = True + break + removed_blocks.append(block) + blocks[group_id][i] = self._null_block + if met_null_block: break - removed_blocks.append(blocks[i]) - blocks[i] = self._null_block return removed_blocks + def get_num_new_blocks( + self, + total_num_tokens: int, + num_computed_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]], + allocated_blocks: list[list[KVCacheBlock]], + group_ids: Sequence[int], + ) -> dict[int, int]: + num_required_blocks = cdiv(total_num_tokens, self.block_size) + num_new_blocks: dict[int, int] = {} + for group_id in group_ids: + num_new_blocks[group_id] = (num_required_blocks - + len(allocated_blocks[group_id]) - + len(new_computed_blocks[group_id])) + return num_new_blocks + + def sort_by_eviction_order( + self, + blocks: list[KVCacheBlock], + ) -> Iterable[KVCacheBlock]: + return reversed(blocks) + -spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = { - FullAttentionSpec: FullAttentionManager, - SlidingWindowSpec: SlidingWindowManager, +spec_allocator_map: dict[type[KVCacheSpec], type[SpecializedAllocator]] = { + FullAttentionSpec: FullAttentionAllocator, + SlidingWindowSpec: SlidingWindowAllocator, } -def get_specialized_manager(kv_cache_spec: KVCacheSpec, - block_pool: BlockPool) -> SpecializedManager: - manager_class = spec_manager_map[type(kv_cache_spec)] - manager = manager_class(kv_cache_spec, block_pool) - return manager +def get_specialized_allocator( + kv_cache_spec: KVCacheSpec, + block_pool: BlockPool, +) -> SpecializedAllocator: + allocator_class = spec_allocator_map[type(kv_cache_spec)] + allocator = allocator_class(kv_cache_spec, block_pool) + return allocator diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index fd949264885b..8fe1630616a4 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -19,7 +19,7 @@ class PrefixCacheStats: # The number of requests in this update. requests: int = 0 # The number of queries in these requests. 
Note that "queries" here - # means the number of blocks that were queried from the cache. + # means the number of tokens that were queried from the cache. queries: int = 0 # The number of hits in these requests. hits: int = 0 diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a64cb97e0123..8f67e95b406d 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -30,7 +30,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int output_token_ids: list[int]