From e496ab5c32eb9c5c423d981d8295ecd1e22e321d Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Sat, 2 Mar 2024 03:50:01 -0500 Subject: [PATCH] Add Automatic Prefix Caching (#2762) Co-authored-by: ElizaWszola Co-authored-by: Michael Goin --- benchmarks/benchmark_throughput.py | 30 ++- docs/source/models/engine_args.rst | 4 + examples/offline_inference_with_prefix.py | 11 +- tests/prefix_caching/test_prefix_caching.py | 103 ++++--- tests/test_cache_block_hashing.py | 76 ++++++ vllm/block.py | 14 +- vllm/config.py | 2 + vllm/core/block_manager.py | 285 +++++++++++++++----- vllm/core/evictor.py | 161 +++++++++++ vllm/core/scheduler.py | 15 +- vllm/engine/arg_utils.py | 9 +- vllm/engine/async_llm_engine.py | 14 +- vllm/engine/llm_engine.py | 26 +- vllm/entrypoints/api_server.py | 6 +- vllm/entrypoints/llm.py | 14 +- vllm/prefix.py | 87 ------ vllm/sequence.py | 23 +- vllm/worker/model_runner.py | 30 ++- 18 files changed, 618 insertions(+), 292 deletions(-) create mode 100644 tests/test_cache_block_hashing.py create mode 100644 vllm/core/evictor.py delete mode 100644 vllm/prefix.py diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1ad502526c97c..51c1a6540a451 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -73,21 +73,21 @@ def run_vllm( enforce_eager: bool, kv_cache_dtype: str, device: str, + enable_prefix_caching: bool, ) -> float: from vllm import LLM, SamplingParams - llm = LLM( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - device=device, - ) + llm = LLM(model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + 
enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + device=device, + enable_prefix_caching=enable_prefix_caching) # Add the requests to the engine. for prompt, _, output_len in requests: @@ -211,7 +211,8 @@ def main(args: argparse.Namespace): args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, args.device) + args.kv_cache_dtype, args.device, + args.enable_prefix_caching) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -302,6 +303,7 @@ def main(args: argparse.Namespace): default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') + parser.add_argument("--enable_prefix_caching", action='store_true') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d89b795149501..9f5f672ae4f34 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -81,6 +81,10 @@ Below, you can find an explanation of every engine argument for vLLM: Token block size for contiguous chunks of tokens. +.. option:: --enable-prefix-caching + + Enables automatic prefix caching + .. option:: --seed Random seed for operations. diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py index 8ccfb1ceea731..1aa718b88907c 100644 --- a/examples/offline_inference_with_prefix.py +++ b/examples/offline_inference_with_prefix.py @@ -37,20 +37,13 @@ print("-" * 80) -# -1 since the last token can change when concatenating prompts. -prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 - # The llm.generate call will batch all prompts and send the batch at once if resources allow. 
# The prefix will only be cached after the first batch is processed, so we need to call generate once # to calculate the prefix and cache it. -outputs = llm.generate(generating_prompts[0], - sampling_params, - prefix_pos=[prefix_pos]) +outputs = llm.generate(generating_prompts[0], sampling_params) # Subsequent batches can leverage the cached prefix -outputs = llm.generate(generating_prompts, - sampling_params, - prefix_pos=[prefix_pos] * len(generating_prompts)) +outputs = llm.generate(generating_prompts, sampling_params) # Print the outputs. You should see the same outputs as before for output in outputs: diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 1e301bedfc21e..7ef8dde7bb8f6 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -4,38 +4,73 @@ """ import pytest -from vllm import LLM, SamplingParams - -prefix = ( - "You are an expert school principal, skilled in effectively managing " - "faculty and staff. Draft 10-15 questions for a potential first grade " - "Head Teacher for my K-12, all-girls', independent school that emphasizes " - "community, joyful discovery, and life-long learning. The candidate is " - "coming in for a first-round panel interview for a 8th grade Math " - "teaching role. They have 5 years of previous teaching experience " - "as an assistant teacher at a co-ed, public school with experience " - "in middle school math teaching. 
Based on these information, fulfill " - "the following paragraph: ") - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("max_tokens", [16]) -def test_prefix_caching( - example_prompts, - model: str, - max_tokens: int, +from vllm.core.block_manager import BlockAllocator +from vllm.utils import Device + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_blocks", [16]) +def test_block_allocator( + block_size: int, + num_blocks: int, ): - llm = LLM(model=model) - # -1 since the last token can change when concatenating prompts. - prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 - prompts = [prefix + prompt for prompt in example_prompts] - sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs_without_prefix = llm.generate(prompts, sampling_params) - outputs_with_prefix = llm.generate(prompts, - sampling_params, - prefix_pos=[prefix_pos] * len(prompts)) - for output_without_prefix, output_with_prefix in zip( - outputs_without_prefix, outputs_with_prefix): - assert (output_without_prefix.outputs[0].token_ids == - output_with_prefix.outputs[0].token_ids) - assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1 + block_hash = 1 + block_allocator = BlockAllocator(Device.CPU, + block_size, + num_blocks, + enable_caching=True) + + # Allocate two PhysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock + first_block = block_allocator.allocate(block_hash, 0) + second_block = block_allocator.allocate(block_hash, 0) + assert (first_block == second_block) + assert (second_block.ref_count == 2) + + # Free the first_block and confirm that the ref_count is correctly decremented on the second block + block_allocator.free(first_block) + assert (second_block.ref_count == 1) + + # Free the second block + block_allocator.free(second_block) + + # Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still 
get the same block back + first_block = block_allocator.allocate(block_hash, 0) + assert (first_block == second_block) + assert (first_block.block_hash == block_hash) + + +@pytest.mark.parametrize("num_blocks", [16]) +def test_eviction(num_blocks: int, ): + block_size = 16 + block_allocator = BlockAllocator(Device.CPU, + block_size, + num_blocks, + enable_caching=True) + blocks = [] + + for i in range(num_blocks): + # use i as the block_hash + blocks.append(block_allocator.allocate(i, 0)) + + #Free all blocks + for block in blocks: + block_allocator.free(block) + + # Allocate a new block and confirm that it's the first block freed. I.E The Least Recently Used block + new_block_hash = block_size + new_block = block_allocator.allocate(new_block_hash, 0) + assert (new_block == blocks[0]) + assert (new_block.block_hash == new_block_hash) + + # Reallocate the second in blocks to remove it from the free list + realloc_block_hash = 1 + realloc_block = block_allocator.allocate(realloc_block_hash, 0) + assert (realloc_block == blocks[realloc_block_hash]) + assert (realloc_block.block_hash == realloc_block_hash) + + # Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list + new_block_hash = block_size + 1 + new_block = block_allocator.allocate(new_block_hash, 0) + assert (realloc_block != new_block) + assert (new_block.block_hash == new_block_hash) + assert (new_block.block_number == 2) diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py new file mode 100644 index 0000000000000..7c4ade7f8c8ed --- /dev/null +++ b/tests/test_cache_block_hashing.py @@ -0,0 +1,76 @@ +"""Test hashing of cache blocks. + +Run `pytest tests/test_cache_block_hashing.py`. +""" +import pytest + +from vllm.transformers_utils.tokenizer import TokenizerGroup +from vllm.sequence import Sequence + +# Make two prefixes with different first blocks. 
+prefix_start = [("You are an expert"), ("You are a")] +prefix_common = ( + " school principal, skilled in effectively managing " + "faculty and staff. Draft 10-15 questions for a potential first grade " + "Head Teacher for my K-12, all-girls', independent school that emphasizes " + "community, joyful discovery, and life-long learning. The candidate is " + "coming in for a first-round panel interview for a 8th grade Math " + "teaching role. They have 5 years of previous teaching experience " + "as an assistant teacher at a co-ed, public school with experience " + "in middle school math teaching. Based on this, fulfill " + "the following: ") +prefixes = [start + prefix_common for start in prefix_start] + +# Sample prompts. +sample_prompts = [ + "Hello, my name is", "The president of the United States is", + "The capital of France is", "The future of AI is" +] + + +# Helper function. +def flatten_2d(li): + return [lss for ls in li for lss in ls] + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("max_num_seqs", [256]) +def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int): + + tokenizer = TokenizerGroup( + tokenizer_id="facebook/opt-125m", + enable_lora=False, + max_num_seqs=max_num_seqs, + max_input_length=None, + ) + + hashes = [] + + for prefix in prefixes: + hashes.append([]) + prompts = [prefix + prompt for prompt in sample_prompts] + seq_id = 0 + for prompt in prompts: + hashes[-1].append([]) + prompt_token_ids = tokenizer.encode(prompt) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + + num_blocks = len(prompt_token_ids) // block_size + for idx in range(num_blocks): + hashes[-1][-1].append(seq.hash_of_block(idx)) + + seq_id += 1 + + # Check that hashes made with two prefixes with different first blocks are + # different everywhere. 
+ for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])): + assert (hash0 != hash1) + + # Check that hashes of different prompts made with the same prefix are the + # same until the hashes that contain the prompt. + for hash_pref in hashes: + same_hashes = [tuple(h[:-1]) for h in hash_pref] + different_hashes = [h[-1] for h in hash_pref] + assert (len(set(same_hashes)) == 1) + assert (len(set(different_hashes)) == len(different_hashes)) diff --git a/vllm/block.py b/vllm/block.py index 5fe39ed47b2ff..2cc6b947f2255 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -5,6 +5,8 @@ _BLANK_TOKEN_ID = -1 +DEFAULT_LAST_ACCESSED_TIME = -1 + class LogicalTokenBlock: """A block that stores a contiguous chunk of tokens from left to right. @@ -55,17 +57,27 @@ def __init__( device: Device, block_number: int, block_size: int, + block_hash: int, + num_hashed_tokens: int, ) -> None: self.device = device self.block_number = block_number self.block_size = block_size + self.block_hash = block_hash + self.num_hashed_tokens = num_hashed_tokens self.ref_count = 0 + self.last_accessed = DEFAULT_LAST_ACCESSED_TIME + + self.computed = False def __repr__(self) -> str: return (f'PhysicalTokenBlock(device={self.device}, ' f'block_number={self.block_number}, ' - f'ref_count={self.ref_count})') + f'num_hashed_tokens={self.num_hashed_tokens}, ' + f'ref_count={self.ref_count}, ' + f'last_accessed={self.last_accessed}, ' + f'computed={self.computed})') # Mapping: logical block number -> physical block. 
diff --git a/vllm/config.py b/vllm/config.py index ff8536c1aca55..876a439cd1280 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -303,12 +303,14 @@ def __init__( swap_space: int, cache_dtype: str, sliding_window: Optional[int] = None, + enable_prefix_caching: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB self.cache_dtype = cache_dtype self.sliding_window = sliding_window + self.enable_prefix_caching = enable_prefix_caching self._verify_args() self._verify_cache_dtype() diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 3946096d4296a..08d519ab767a9 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -1,10 +1,13 @@ """A block manager that manages token blocks.""" import enum +from itertools import count +from os.path import commonprefix from typing import Dict, List, Optional, Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device +from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor class BlockAllocator: @@ -15,29 +18,68 @@ class BlockAllocator: the reference count becomes zero, the block is added back to the free list. 
""" - def __init__( - self, - device: Device, - block_size: int, - num_blocks: int, - ) -> None: + def __init__(self, + device: Device, + block_size: int, + num_blocks: int, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU, + enable_caching: bool = False) -> None: self.device = device self.block_size = block_size self.num_blocks = num_blocks + self.enable_caching = enable_caching + + self.current_num_blocks = 0 + self.cached_blocks: Dict[int, PhysicalTokenBlock] = {} + + # Switch over to FIFO eviction when caching is disabled + if not self.enable_caching: + eviction_policy = EvictionPolicy.FIFO + self.evictor: Evictor = make_evictor(eviction_policy) + + self.default_hash_ctr = count() + + def allocate_block(self, block_hash: int, + num_hashed_tokens: int) -> PhysicalTokenBlock: + if self.current_num_blocks == self.num_blocks: + block = self.evictor.evict() + block.block_hash = block_hash + block.num_hashed_tokens = num_hashed_tokens + return block + block = PhysicalTokenBlock(device=self.device, + block_number=self.current_num_blocks, + block_size=self.block_size, + block_hash=block_hash, + num_hashed_tokens=num_hashed_tokens) + self.current_num_blocks += 1 + return block - # Initialize the free blocks. - self.free_blocks: BlockTable = [] - for i in range(num_blocks): - block = PhysicalTokenBlock(device=device, - block_number=i, - block_size=block_size) - self.free_blocks.append(block) - - def allocate(self) -> PhysicalTokenBlock: - if not self.free_blocks: - raise ValueError("Out of memory! 
No free blocks are available.") - block = self.free_blocks.pop() - block.ref_count = 1 + def allocate(self, + block_hash: Optional[int] = None, + num_hashed_tokens: int = 0) -> PhysicalTokenBlock: + # If caching is disabled, just allocate a new block and return it + if not self.enable_caching: + block = self.allocate_block(next(self.default_hash_ctr), + num_hashed_tokens) + block.ref_count += 1 + return block + + if block_hash is None: + block_hash = next(self.default_hash_ctr) + if block_hash in self.evictor: + assert block_hash not in self.cached_blocks + block = self.evictor.remove(block_hash) + assert block.ref_count == 0 + self.cached_blocks[block_hash] = block + block.ref_count += 1 + assert block.block_hash == block_hash + return block + if block_hash not in self.cached_blocks: + self.cached_blocks[block_hash] = self.allocate_block( + block_hash, num_hashed_tokens) + block = self.cached_blocks[block_hash] + assert block.block_hash == block_hash + block.ref_count += 1 return block def free(self, block: PhysicalTokenBlock) -> None: @@ -45,10 +87,27 @@ def free(self, block: PhysicalTokenBlock) -> None: raise ValueError(f"Double free! {block} is already freed.") block.ref_count -= 1 if block.ref_count == 0: - self.free_blocks.append(block) + assert block.block_hash not in self.evictor + self.evictor.add(block) + + # If caching is enabled, remove the block from the cached_blocks + if self.enable_caching: + del self.cached_blocks[block.block_hash] def get_num_free_blocks(self) -> int: - return len(self.free_blocks) + return self.num_blocks - self.current_num_blocks + self.evictor.num_blocks + + def contains_block(self, block_hash: int) -> bool: + return block_hash in self.cached_blocks or block_hash in self.evictor + + def update_hash(self, block_hash: int, block: PhysicalTokenBlock): + # If caching is enabled, update the hash of block and the cached_blocks dictionary. 
+ if self.enable_caching: + assert not self.contains_block(block_hash) + old_hash = block.block_hash + block.block_hash = block_hash + del self.cached_blocks[old_hash] + self.cached_blocks[block_hash] = block class AllocStatus(enum.Enum): @@ -75,6 +134,7 @@ def __init__( num_cpu_blocks: int, watermark: float = 0.01, sliding_window: Optional[int] = None, + enable_caching: bool = False, ) -> None: self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks @@ -89,11 +149,17 @@ def __init__( self.watermark = watermark assert watermark >= 0.0 + self.enable_caching = enable_caching + self.watermark_blocks = int(watermark * num_gpu_blocks) - self.gpu_allocator = BlockAllocator(Device.GPU, block_size, - num_gpu_blocks) - self.cpu_allocator = BlockAllocator(Device.CPU, block_size, - num_cpu_blocks) + self.gpu_allocator = BlockAllocator(Device.GPU, + block_size, + num_gpu_blocks, + enable_caching=enable_caching) + self.cpu_allocator = BlockAllocator(Device.CPU, + block_size, + num_cpu_blocks, + enable_caching=enable_caching) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} @@ -103,9 +169,6 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = len(seq.logical_token_blocks) - if seq_group.prefix is not None and seq_group.prefix.allocated: - num_required_blocks -= seq_group.prefix.get_num_blocks() - if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -129,36 +192,16 @@ def allocate(self, seq_group: SequenceGroup) -> None: num_prompt_blocks = len(seq.logical_token_blocks) block_table: BlockTable = [] - prefix_block_table: BlockTable = [] - num_prefix_blocks = 0 - - prefix = seq_group.prefix - if prefix is not None and prefix.allocated: - # Prefix has already been allocated. Use the existing block table. 
- num_prompt_blocks -= prefix.get_num_blocks() - for block in prefix.block_table: - block.ref_count += seq_group.num_seqs() - block_table.append(block) - for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] else: - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block = self.gpu_allocator.allocate( + seq.hash_of_block(logical_idx), + seq.num_hashed_tokens_of_block(logical_idx)) block_table.append(block) - if prefix is not None and not prefix.allocated: - # Allocate blocks for the prefix, we will compute the prefix's - # KV cache in this run. - num_prefix_blocks = prefix.get_num_blocks() - prefix_block_table = block_table[:num_prefix_blocks] - for block in prefix_block_table: - block.ref_count += 1 - prefix.set_block_table(prefix_block_table) - # Assign the block table for each sequence. 
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() @@ -170,12 +213,72 @@ def can_append_slot(self, seq_group: SequenceGroup) -> bool: num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) return num_seqs <= num_free_gpu_blocks - def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: + def _promote_last_block( + self, + seq: Sequence, + last_block: PhysicalTokenBlock, + ) -> PhysicalTokenBlock: + # Compute a new hash for the block so that it can be shared by other Sequences + new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) + + # if new_hash is already in the cached table, then free last_block and return the cached version + if self.gpu_allocator.contains_block(new_hash): + self.gpu_allocator.free(last_block) + return self.gpu_allocator.allocate(new_hash) + else: + self.gpu_allocator.update_hash(new_hash, last_block) + return last_block + + def _is_last_block_full( + self, + seq: Sequence, + ) -> bool: + token_ids_len = len(seq.data.get_token_ids()) + return token_ids_len > 0 and token_ids_len % seq.block_size == 0 + + def _is_last_block( + self, + seq: Sequence, + index: int, + ) -> bool: + return index == len(seq.logical_token_blocks) - 1 + + def _maybe_promote_last_block( + self, + seq: Sequence, + last_block: PhysicalTokenBlock, + ) -> PhysicalTokenBlock: + if self._is_last_block_full(seq): + return self._promote_last_block(seq, last_block) + else: + return last_block + + def _allocate_last_physical_block( + self, + seq: Sequence, + ) -> PhysicalTokenBlock: + block_hash: Optional[int] = None + if (self._is_last_block_full(seq)): + block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) + num_hashed_tokens = seq.num_hashed_tokens_of_block( + len(seq.logical_token_blocks) - 1) + new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens) + if block_hash is None: + assert new_block.ref_count == 1 + return new_block + + def append_slot( + self, + seq: 
Sequence, + ) -> Optional[Tuple[int, int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] - + # If we need to allocate a new physical block if len(block_table) < len(logical_blocks): + # Currently this code only supports adding one physical block + assert len(block_table) == len(logical_blocks) - 1 + if (self.block_sliding_window and len(block_table) >= self.block_sliding_window): # reuse a block @@ -184,8 +287,8 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: else: # The sequence has a new logical block. # Allocate a new physical block. - block = self.gpu_allocator.allocate() - block_table.append(block) + new_block = self._allocate_last_physical_block(seq) + block_table.append(new_block) return None # We want to append the token to the last physical block. @@ -193,11 +296,15 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: assert last_block.device == Device.GPU if last_block.ref_count == 1: # Not shared with other sequences. Appendable. + # If the last block is now complete, promote it to a full block so that it can be shared + new_block = self._maybe_promote_last_block(seq, last_block) + block_table[-1] = new_block return None else: # The last block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. - new_block = self.gpu_allocator.allocate() + new_block = self._allocate_last_physical_block(seq) + block_table[-1] = new_block self.gpu_allocator.free(last_block) return last_block.block_number, new_block.block_number @@ -233,25 +340,18 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. 
- if seq_group.prefix is not None: - # make sure to swap in the prefix first - assert seq_group.prefix.allocated and seq_group.prefix.computed - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] - if seq_group.prefix is not None: - for block in seq_group.prefix.block_table: - new_block_table.append(block) - block.ref_count += 1 for cpu_block in block_table: if cpu_block in mapping: gpu_block = mapping[cpu_block] gpu_block.ref_count += 1 else: - gpu_block = self.gpu_allocator.allocate() + gpu_block = self.gpu_allocator.allocate( + cpu_block.block_hash, cpu_block.num_hashed_tokens) mapping[cpu_block] = gpu_block new_block_table.append(gpu_block) # Free the CPU block swapped in to GPU. @@ -276,17 +376,12 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: block_table = self.block_tables[seq.seq_id] for gpu_block in block_table: - if (seq_group.prefix is not None - and gpu_block in seq_group.prefix.block_table): - # NOTE: We do not swap out the prefix blocks for now. - self.gpu_allocator.free(gpu_block) - continue - if gpu_block in mapping: cpu_block = mapping[gpu_block] cpu_block.ref_count += 1 else: - cpu_block = self.cpu_allocator.allocate() + cpu_block = self.cpu_allocator.allocate( + gpu_block.block_hash, gpu_block.num_hashed_tokens) mapping[gpu_block] = cpu_block new_block_table.append(cpu_block) # Free the GPU block swapped out to CPU. 
@@ -328,3 +423,49 @@ def get_num_free_gpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int: return self.cpu_allocator.get_num_free_blocks() + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + block_table = self.block_tables[seq.seq_id] + for block in block_table: + block.last_accessed = access_time + + def compute_last_full_block_in_seq(self, seq: Sequence): + if seq.seq_id not in self.block_tables: + return + max_full_block = seq.get_len() // seq.block_size - 1 + block_table = self.block_tables[seq.seq_id] + if max_full_block == -1: + return + block_table[max_full_block].computed = True + + def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]: + if seq.seq_id not in self.block_tables: + return [] + block_table = self.block_tables[seq.seq_id] + for block_idx in reversed(range(len(block_table))): + if block_table[block_idx].computed: + return [b.block_number for b in block_table[:block_idx + 1]] + return [] + + # Can return non-empty result only with prefix caching enabled. + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + if not self.enable_caching: + return [] + + ids_list = [ + self.get_all_block_ids_till_computed(seq) + for seq in iter(seq_group.seqs_dict.values()) + ] + return commonprefix([ids for ids in ids_list if ids != []]) + + # We only mark the last full block because with prefix caching, + # all blocks until the marked one are guaranteed to be computed. 
+ def mark_blocks_as_computed(self, seq_group: SequenceGroup): + if self.enable_caching: + for seq in seq_group.seqs_dict.values(): + self.compute_last_full_block_in_seq(seq) diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py new file mode 100644 index 0000000000000..b538ea574b604 --- /dev/null +++ b/vllm/core/evictor.py @@ -0,0 +1,161 @@ +import enum +from typing import Dict, List, Optional +from abc import ABC, abstractmethod, abstractproperty + +from vllm.block import PhysicalTokenBlock + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. + """ + LRU = enum.auto() + FIFO = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed PhysicalTokenBlocks. + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_hash: int) -> bool: + pass + + @abstractmethod + def evict(self) -> PhysicalTokenBlock: + """Runs the eviction algorithm and returns the evicted block""" + pass + + @abstractmethod + def add(self, block: PhysicalTokenBlock): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def remove(self, block_hash: int) -> PhysicalTokenBlock: + """Simply removes the block with the hash value block_hash from the + evictor. Caller is responsible for making sure that block_hash is contained + in the evictor before calling remove. Should be used to "bring back" blocks + that have been freed but not evicted yet. + """ + pass + + @abstractproperty + def num_blocks(self) -> int: + pass + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the PhysicalTokenBlock. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. 
If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chosen arbitrarily + """ + + def __init__(self): + self.free_table: Dict[int, PhysicalTokenBlock] = {} + + def __contains__(self, block_hash: int) -> bool: + return block_hash in self.free_table + + # TODO: The performance of this evict function can be optimized further. + def evict(self) -> PhysicalTokenBlock: + free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values()) + if len(free_blocks) == 0: + raise ValueError("No usable cache memory left") + + # Find lowest timestamp + lowest_timestamp = free_blocks[0].last_accessed + for block in free_blocks: + if block.last_accessed < lowest_timestamp: + lowest_timestamp = block.last_accessed + + # Find all blocks with the lowest timestamp + least_recent: List[PhysicalTokenBlock] = [] + for block in free_blocks: + if block.last_accessed == lowest_timestamp: + least_recent.append(block) + + # Find the highest num_hashed_tokens among the least recent blocks + highest_num_hashed_tokens = 0 + for block in least_recent: + if block.num_hashed_tokens > highest_num_hashed_tokens: + highest_num_hashed_tokens = block.num_hashed_tokens + + evicted_block: Optional[PhysicalTokenBlock] = None + + # Find the first least recent block with the highest num_hashed_tokens + for block in least_recent: + if block.num_hashed_tokens == highest_num_hashed_tokens: + evicted_block = block + break + + assert evicted_block is not None + + del self.free_table[evicted_block.block_hash] + + evicted_block.computed = False + return evicted_block + + def add(self, block: PhysicalTokenBlock): + self.free_table[block.block_hash] = block + + def remove(self, block_hash: int) -> PhysicalTokenBlock: + if block_hash not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + block: PhysicalTokenBlock = self.free_table[block_hash] + del self.free_table[block_hash] + return block + + @property + def num_blocks(self) -> int: + return 
len(self.free_table) + + +class RandomEvictor(Evictor): + """Evicts in a first-in-first-out order""" + + def __init__(self): + self.free_table: Dict[int, PhysicalTokenBlock] = {} + + def __contains__(self, block_hash: int) -> bool: + return block_hash in self.free_table + + def evict(self) -> PhysicalTokenBlock: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + evicted_block = next(iter(self.free_table.values())) + evicted_block.computed = False + del self.free_table[evicted_block.block_hash] + return evicted_block + + def add(self, block: PhysicalTokenBlock): + self.free_table[block.block_hash] = block + + def remove(self, block_hash: int) -> PhysicalTokenBlock: + if block_hash not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + block: PhysicalTokenBlock = self.free_table[block_hash] + del self.free_table[block_hash] + return block + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + elif eviction_policy == EvictionPolicy.FIFO: + return RandomEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5e7cc3091d775..1ae58f525b0fb 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -10,7 +10,6 @@ from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.prefix import PrefixPool logger = init_logger(__name__) @@ -95,10 +94,8 @@ def __init__( block_size=self.cache_config.block_size, num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, - sliding_window=self.cache_config.sliding_window) - - # Create the prefix pool to cache the prefixes. 
- self.prefix_pool = PrefixPool(self.cache_config.block_size) + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching) # Sequence groups in the WAITING state. self.waiting: Deque[SequenceGroup] = deque() @@ -374,10 +371,12 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data: Dict[int, SequenceData] = {} block_tables: Dict[int, List[int]] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) + self.block_manager.access_all_blocks_in_seq(seq, now) seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, @@ -386,7 +385,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: sampling_params=seq_group.sampling_params, block_tables=block_tables, lora_request=seq_group.lora_request, - prefix=seq_group.prefix, + computed_block_nums=self.block_manager. 
+ get_common_computed_block_ids(seq_group), state=seq_group.state, ) seq_group_metadata_list.append(seq_group_metadata) @@ -496,3 +496,6 @@ def _swap_out( blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED + + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + self.block_manager.mark_blocks_as_computed(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c01e7311fb89a..0349c3a6636c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -25,6 +25,7 @@ class EngineArgs: tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None block_size: int = 16 + enable_prefix_caching: bool = False swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None @@ -173,6 +174,11 @@ def add_cli_args( default=EngineArgs.block_size, choices=[8, 16, 32, 128], help='token block size') + + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='Enables automatic prefix caching') + parser.add_argument('--seed', type=int, default=EngineArgs.seed, @@ -293,7 +299,8 @@ def create_engine_configs( cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window()) + model_config.get_sliding_window(), + self.enable_prefix_caching) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index daa6419cdad3b..9e52d20ca4980 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -225,7 +225,6 @@ async def add_request_async( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - prefix_pos: Optional[int] = None, ) -> None: if lora_request is not None and not 
self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -245,7 +244,6 @@ async def add_request_async( sampling_params=sampling_params, arrival_time=arrival_time, lora_request=lora_request, - prefix_pos=prefix_pos, ) async def _run_workers_async( @@ -422,7 +420,6 @@ async def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - prefix_pos: Optional[int] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -435,7 +432,6 @@ async def add_request( max_log_len] logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " - f"prefix_pos: {prefix_pos}," f"sampling_params: {sampling_params}, " f"prompt_token_ids: {shortened_token_ids}, " f"lora_request: {lora_request}.") @@ -472,8 +468,7 @@ async def add_request( sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, - lora_request=lora_request, - prefix_pos=prefix_pos) + lora_request=lora_request) return stream @@ -484,7 +479,6 @@ async def generate( request_id: str, prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - prefix_pos: Optional[int] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -500,11 +494,6 @@ async def generate( prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - prefix_pos: If not None, we use the given position as the prefix - position for each prompt. We will cache the prefix's KV - cache and reuse it for the next request with the same prefix. - This is an experimental feature, and may be replaced with - automatic prefix caching in the future. 
Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -565,7 +554,6 @@ async def generate( prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - prefix_pos=prefix_pos, ) async for request_output in stream: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index df4858a696530..e84fda5640e4d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -415,7 +415,6 @@ def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - prefix_pos: Optional[int] = None, ) -> None: """Add a request to the engine's request pool. @@ -432,11 +431,6 @@ def add_request( use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - prefix_pos: If not None, we use the given position as the prefix - position for each prompt. We will cache the prefix's KV - cache and reuse it for the next request with the same prefix. - This is an experimental feature, and may be replaced with - automatic prefix caching in the future. Details: - Set arrival_time to the current time if it is None. @@ -479,18 +473,13 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, lora_request) - # Check whether the input specifies prefix - prefix = self.scheduler.prefix_pool.add_or_get_prefix( - prompt_token_ids[:prefix_pos], lora_request.lora_int_id - if lora_request else 0) if prefix_pos is not None else None - # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, prefix) + arrival_time, lora_request) # Add the sequence group to the scheduler. 
self.scheduler.add_seq_group(seq_group) @@ -752,6 +741,13 @@ def _process_model_outputs( now = time.time() # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups + + # If prefix caching is enabled, mark all blocks in the sequence groups + # as completed so that future requests don't attempt to recompute them + if self.cache_config.enable_prefix_caching: + for seq_group in scheduled_seq_groups: + self.scheduler.mark_blocks_as_computed(seq_group) + for seq_group, outputs in zip(scheduled_seq_groups, output): self._process_sequence_group_outputs(seq_group, outputs) @@ -768,12 +764,6 @@ def _process_model_outputs( request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - # Update prefix state, now all the uncomputed prefixes are computed. - for seq_group in scheduled_seq_groups: - if (seq_group.prefix is not None and seq_group.prefix.allocated - and not seq_group.prefix.computed): - seq_group.prefix.computed = True - # Log stats. 
if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index e7af2c6db5e4c..1eb4ab8b06b64 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -39,15 +39,11 @@ async def generate(request: Request) -> Response: """ request_dict = await request.json() prompt = request_dict.pop("prompt") - prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = engine.generate(prompt, - sampling_params, - request_id, - prefix_pos=prefix_pos) + results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fc82018d18eb6..62f1d172377f6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -124,7 +124,6 @@ def generate( prompts: Optional[Union[str, List[str]]] = None, sampling_params: Optional[SamplingParams] = None, prompt_token_ids: Optional[List[List[int]]] = None, - prefix_pos: Optional[Union[int, List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, ) -> List[RequestOutput]: @@ -140,11 +139,6 @@ def generate( None, we use the default sampling parameters. prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. - prefix_pos: If not None, we use the given position as the prefix - position for each prompt. We will cache the prefix's KV - cache and reuse it for the next request with the same prefix. - This is an experimental feature, and may be replaced with - automatic prefix caching in the future. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. 
@@ -171,14 +165,12 @@ def generate( prompt_token_ids) for i in range(num_requests): prompt = prompts[i] if prompts is not None else None - prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] self._add_request(prompt, sampling_params, token_ids, - lora_request=lora_request, - prefix_pos=prefix_pos_i) + lora_request=lora_request) return self._run_engine(use_tqdm) def _add_request( @@ -187,15 +179,13 @@ def _add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, - prefix_pos: Optional[int] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, prompt, sampling_params, prompt_token_ids, - lora_request=lora_request, - prefix_pos=prefix_pos) + lora_request=lora_request) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. diff --git a/vllm/prefix.py b/vllm/prefix.py deleted file mode 100644 index 5b6e8e4b92be6..0000000000000 --- a/vllm/prefix.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Dict, List, Sequence, Tuple, Optional - -from vllm.block import BlockTable - - -class Prefix: - """Data and states associated with a prefix of prompt tokens for multiple - sequence groups. - - NOTE: This feature is experimental and may be replaced with automatic - prefix caching in the future. - - Args: - token_ids: The token ids of the prefix. - block_size: The block size of the executed model. 
- """ - - def __init__( - self, - token_ids: Sequence[int], - block_size: int, - ) -> None: - self.token_ids = tuple(token_ids) - self.block_size = block_size - self.length = len(token_ids) - self.hash = hash(token_ids) - assert self.length % block_size == 0 - self.block_table: Optional[BlockTable] = None - self.computed = False - - @property - def allocated(self) -> bool: - return self.block_table is not None - - def get_num_blocks(self) -> int: - return self.length // self.block_size - - def get_block_numbers(self) -> List[int]: - return [block.block_number for block in self.block_table] - - def get_length(self) -> int: - return self.length - - def __hash__(self) -> int: - return self.hash - - def set_block_table(self, block_table: BlockTable) -> None: - self.block_table = block_table.copy() - - -class PrefixPool: - """Manages all the prompt prefixes. - - NOTE: This feature is experimental and may be replaced with automatic - prefix caching in the future. - - Args: - block_size: The block size of the executed model. - - Attributes: - prefixes: A list of all the prefixes. - block_size: The block size of the executed model. - """ - - def __init__( - self, - block_size: int, - ) -> None: - # TODO(zhuohan): Add a capacity limit to the prefix pool. - self.prefixes: Dict[int, Prefix] = {} - self.block_size = block_size - - def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: - new_length = len(token_ids) // self.block_size * self.block_size - return tuple(token_ids[:new_length]) - - def add_or_get_prefix(self, token_ids: Sequence[int], - lora_int_id: int) -> Optional[Prefix]: - token_ids = self._truncate_token_ids(token_ids) - if len(token_ids) == 0: - # Prefix is empty. 
- return None - prefix = Prefix(token_ids, self.block_size) - prefix_hash = hash((prefix, lora_int_id)) - if prefix_hash not in self.prefixes: - self.prefixes[prefix_hash] = prefix - return self.prefixes[prefix_hash] diff --git a/vllm/sequence.py b/vllm/sequence.py index 040e9756e15c6..122960035e505 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,7 +5,6 @@ from typing import Dict, List, Optional, Union from vllm.block import LogicalTokenBlock -from vllm.prefix import Prefix from vllm.sampling_params import SamplingParams from vllm.lora.request import LoRARequest @@ -161,6 +160,16 @@ def __init__( def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + # TODO The current hashing function is O(L^2). We should optimize this in + # the future. + def hash_of_block(self, logical_idx: int) -> int: + # Compute the number of tokens in the sequence + num_tokens = self.num_hashed_tokens_of_block(logical_idx) + return hash(tuple(self.data.get_token_ids()[0:num_tokens])) + + def num_hashed_tokens_of_block(self, logical_idx: int): + return logical_idx * self.block_size + self.block_size + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -265,7 +274,6 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - prefix: The prefix of the prompt of the sequence group. 
""" def __init__( @@ -275,7 +283,6 @@ def __init__( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest] = None, - prefix: Optional[Prefix] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -286,7 +293,6 @@ def __init__( first_token_time=None, time_in_queue=None) self.lora_request = lora_request - self.prefix: Optional[Prefix] = prefix self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() @@ -302,6 +308,10 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).data.prompt_token_ids + @property + def block_size(self) -> int: + return next(iter(self.seqs_dict.values())).block_size + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -408,7 +418,6 @@ class SequenceGroupMetadata: numbers) state: Internal state tied to this sequence group. lora_request: LoRA request. - prefix: The prefix of the prompt of the sequence group. 
""" def __init__( @@ -419,7 +428,7 @@ def __init__( sampling_params: SamplingParams, block_tables: Dict[int, List[int]], lora_request: Optional[LoRARequest] = None, - prefix: Optional[Prefix] = None, + computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, ) -> None: self.request_id = request_id @@ -428,7 +437,7 @@ def __init__( self.sampling_params = sampling_params self.block_tables = block_tables self.lora_request = lora_request - self.prefix = prefix + self.computed_block_nums = computed_block_nums self.state = SequenceGroupState() if state is None else state @property diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index efe570778fb43..aff8ebc903623 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -145,33 +145,37 @@ def _prepare_prompt( prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) - prefix_len = 0 - prefix = seq_group_metadata.prefix - if prefix is not None and prefix.computed: - prefix_len = prefix.get_length() - prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_tables.append(prefix.get_block_numbers()) + computed_len = 0 + + # NOTE: This only works for oooooooxxx style attention. 
+ computed_block_nums = seq_group_metadata.computed_block_nums + if computed_block_nums is not None and len( + computed_block_nums) > 0 and self.sliding_window is None: + # Prefix is not supported with sliding_window + computed_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[computed_len:] + prefix_block_tables.append(computed_block_nums) else: prefix_block_tables.append([]) # actual prompt lens - context_lens.append(prefix_len) - subquery_lens.append(prompt_len - prefix_len) + context_lens.append(computed_len) + subquery_lens.append(prompt_len - computed_len) input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.append( - list(range(prefix_len, prefix_len + len(prompt_tokens)))) + list(range(computed_len, computed_len + len(prompt_tokens)))) lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping.append([lora_id] * (prompt_len - prefix_len)) + lora_index_mapping.append([lora_id] * (prompt_len - computed_len)) lora_prompt_mapping.extend( [lora_id] * - (prompt_len - prefix_len + (prompt_len - computed_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.block_tables is None: @@ -190,11 +194,11 @@ def _prepare_prompt( # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: - assert prefix_len == 0, ( + assert computed_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) - for i in range(prefix_len, prompt_len): + for i in range(computed_len, prompt_len): if i < start_idx: slot_mapping[-1].append(_PAD_SLOT_ID) continue