From 1a7eb7da6157541ed7867c9aff94231695f2cee9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 Mar 2023 09:58:21 -0800 Subject: [PATCH] Support beam search & parallel generation (#7) --- cacheflow/block.py | 4 + cacheflow/master/frontend.py | 33 +++- cacheflow/master/scheduler.py | 91 ++++++---- cacheflow/models/__init__.py | 4 +- cacheflow/models/input_metadata.py | 18 +- cacheflow/models/model_utils.py | 10 ++ cacheflow/models/opt.py | 3 +- cacheflow/models/sample.py | 275 +++++++++++++++++++++++++++-- cacheflow/sampling_params.py | 73 ++++++-- cacheflow/sequence.py | 78 +++++++- cacheflow/worker/cache_engine.py | 35 +++- cacheflow/worker/controller.py | 17 +- cacheflow/worker/worker.py | 119 ++++++++----- csrc/cache.cpp | 16 +- csrc/cache_kernels.cu | 32 +++- server.py | 17 +- 16 files changed, 662 insertions(+), 163 deletions(-) diff --git a/cacheflow/block.py b/cacheflow/block.py index 4d5dc2ef12a54..df8d46ab58c07 100644 --- a/cacheflow/block.py +++ b/cacheflow/block.py @@ -35,6 +35,10 @@ def append(self, token_ids: List[int]) -> None: def get_token_ids(self) -> List[int]: return self.token_ids[:self.num_tokens] + def get_last_token_id(self) -> int: + assert self.num_tokens > 0 + return self.token_ids[self.num_tokens - 1] + class PhysicalTokenBlock: diff --git a/cacheflow/master/frontend.py b/cacheflow/master/frontend.py index 23ae2723dc091..cfa17684fd56a 100644 --- a/cacheflow/master/frontend.py +++ b/cacheflow/master/frontend.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple from transformers import AutoTokenizer @@ -25,12 +25,35 @@ def __init__( def query( self, prompt: str, - sampling_params: Optional[SamplingParams] = None, + n: int = 1, + temperature: float = 1.0, + top_p: float = 1.0, + use_beam_search: bool = False, + stop_token_ids: Set[int] = set(), + max_num_steps: int = 16, # From OpenAI API. + num_logprobs: int = 0, + context_window_size: Optional[int] = None, ) -> None: - if sampling_params is None: - sampling_params = SamplingParams() - token_ids: List[int] = self.tokenizer.encode(prompt) + # Stop when we see an EOS token. + stop_token_ids.add(self.tokenizer.eos_token_id) + sampling_params = SamplingParams( + n=n, + temperature=temperature, + top_p=top_p, + use_beam_search=use_beam_search, + stop_token_ids=stop_token_ids, + max_num_steps=max_num_steps, + num_logprobs=num_logprobs, + context_window_size=context_window_size, + ) + token_ids = self.tokenizer.encode(prompt) + self._add_query(token_ids, sampling_params) + def _add_query( + self, + token_ids: List[int], + sampling_params: SamplingParams, + ) -> None: seqs: List[Sequence] = [] for _ in range(sampling_params.n): seq_id = next(self.seq_counter) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 7b12500203e35..7f2ca1455fc43 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -1,10 +1,12 @@ -from typing import Dict, List, Tuple +from typing import Dict, List from cacheflow.master.block_manager import BlockSpaceManager from cacheflow.master.frontend import Frontend from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence from cacheflow.sequence import SequenceGroup +from cacheflow.sequence import SequenceGroupInputs +from cacheflow.sequence import SequenceOutputs from cacheflow.sequence import SequenceStatus _MAX_NUM_BATCHED_TOKENS = 2048 @@ -66,7 +68,7 @@ def _allocate(self, seq_group: SequenceGroup) -> None: def _append( self, seq_group: SequenceGroup, - blocks_to_copy: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], ) -> None: for seq in seq_group.seqs: if seq.status == SequenceStatus.FINISHED: @@ -74,7 +76,10 @@ def _append( ret = self.block_manager.append(seq) if ret is not None: src_block, dst_block = ret - blocks_to_copy[src_block] = dst_block + if src_block in blocks_to_copy: + blocks_to_copy[src_block].append(dst_block) + else: + blocks_to_copy[src_block] = [dst_block] def _swap_in( self, @@ -83,9 +88,8 @@ def _swap_in( ) -> None: mapping = self.block_manager.swap_in(seq_group) blocks_to_swap_in.update(mapping) - for seq in seq_group.seqs: - if seq.status == SequenceStatus.SWAPPED: - seq.status = SequenceStatus.RUNNING + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING self.running.append(seq_group) def _swap_out( @@ -96,16 +100,15 @@ def _swap_out( assert self.block_manager.can_swap_out(seq_group) mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.update(mapping) - for seq in seq_group.seqs: - if seq.status == SequenceStatus.RUNNING: - seq.status = SequenceStatus.SWAPPED + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED self.swapped.append(seq_group) def step(self) -> None: # Blocks that need to be swaped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} - blocks_to_copy: Dict[int, int] = {} + blocks_to_copy: Dict[int, List[int]] = {} # 1. Reserve new slots for the running sequences. # NOTE: Here we implicitly assume FCFS scheduling. @@ -143,6 +146,10 @@ def step(self) -> None: # All swapped sequences are swapped in. self.swapped.clear() + # Ensure that swap-in and swap-out never happen at the same timestep. + if blocks_to_swap_in: + assert not blocks_to_swap_out + num_batched_tokens = sum( seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running @@ -152,7 +159,6 @@ def step(self) -> None: # NOTE: Here we implicitly assume FCFS scheduling. # TODO(woosuk): Add a batching policy to control the batch size. if not self.swapped: - # FIXME(woosuk): Acquire a lock to protect pending. self._fetch_inputs() for i, seq_group in enumerate(self.pending): num_prompt_tokens = seq_group.seqs[0].get_len() @@ -168,39 +174,45 @@ def step(self) -> None: else: self.pending.clear() - # Ensure that swap-in and swap-out never happen at the same timestep. - if blocks_to_swap_in: - assert not blocks_to_swap_out - # 4. Create input data structures. - prompt_tokens: Dict[int, List[int]] = {} - generation_tokens: Dict[int, int] = {} - context_lens: Dict[int, int] = {} - block_tables: Dict[int, List[int]] = {} + input_seq_groups: List[SequenceGroupInputs] = [] for seq_group in self.running: group_id = seq_group.group_id num_steps = self.num_steps[group_id] + # NOTE(woosuk): We assume that the number of steps is 0 # for the prompt sequences. is_prompt = num_steps == 0 - for seq in seq_group.seqs: - if seq.status != SequenceStatus.RUNNING: - continue + input_tokens: Dict[int, List[int]] = {} + seq_logprobs: Dict[int, float] = {} + block_tables: Dict[int, List[int]] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id block_tables[seq_id] = self.block_manager.get_block_table(seq) if is_prompt: - prompt_tokens[seq_id] = seq.get_token_ids() + input_tokens[seq_id] = seq.get_token_ids() else: - generation_tokens[seq_id] = seq.get_token_ids()[-1] - context_lens[seq_id] = seq.get_len() + input_tokens[seq_id] = [seq.get_last_token_id()] + seq_logprobs[seq_id] = seq.cumulative_logprobs + # NOTE(woosuk): Sequences in the same group have the same + # sequence length + seq_len = seq.get_len() + + input_seq_group = SequenceGroupInputs( + group_id=group_id, + is_prompt=is_prompt, + input_tokens=input_tokens, + context_len=seq_len, + seq_logprobs=seq_logprobs, + sampling_params=self.sampling_params[group_id], + block_tables=block_tables, + ) + input_seq_groups.append(input_seq_group) # 5. Execute the first stage of the pipeline. self.controllers[0].execute_stage( - prompt_tokens, - generation_tokens, - context_lens, - block_tables, + input_seq_groups, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, @@ -208,7 +220,7 @@ def step(self) -> None: def post_step( self, - next_tokens: Dict[int, Tuple[int, int]], + seq_outputs: Dict[int, SequenceOutputs], ) -> None: # Update the running sequences and free blocks. for seq_group in self.running: @@ -216,25 +228,32 @@ def post_step( self.num_steps[group_id] += 1 stop_token_ids = self.sampling_params[group_id].stop_token_ids + # Process beam search results before processing the next tokens. for seq in seq_group.seqs: if seq.status == SequenceStatus.FINISHED: continue - parent_seq_id, next_token = next_tokens[seq.seq_id] - if seq.seq_id != parent_seq_id: + output = seq_outputs[seq.seq_id] + if seq.seq_id != output.parent_seq_id: # The sequence is a fork of the parent sequence (beam search). # Free the current sequence. self.block_manager.free(seq) # Fork the parent sequence. - parent_seq = seq_group.find(parent_seq_id) - seq.logical_token_blocks = parent_seq.logical_token_blocks.copy() + parent_seq = seq_group.find(output.parent_seq_id) + parent_seq.fork(seq) self.block_manager.fork(parent_seq, seq) + # Process the next tokens. + for seq in seq_group.seqs: + if seq.status == SequenceStatus.FINISHED: + continue + # Append a new token to the sequence. - seq.append([next_token]) + output = seq_outputs[seq.seq_id] + seq.append(output.output_token, output.logprobs) # Check if the sequence has generated a stop token. - if next_token in stop_token_ids: + if output.output_token in stop_token_ids: self._free_seq(seq) continue diff --git a/cacheflow/models/__init__.py b/cacheflow/models/__init__.py index 498101b53fdd7..67dbd5627cbb3 100644 --- a/cacheflow/models/__init__.py +++ b/cacheflow/models/__init__.py @@ -1,8 +1,10 @@ from cacheflow.models.input_metadata import InputMetadata from cacheflow.models.model_utils import get_model +from cacheflow.models.model_utils import set_seed __all__ = [ - 'get_model', 'InputMetadata', + 'get_model', + 'set_seed' ] diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index 77f25054e38a6..e4787d4a82711 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -1,21 +1,24 @@ -from typing import List +from typing import List, Dict, Tuple import torch +from cacheflow.sampling_params import SamplingParams + class InputMetadata: def __init__( self, - seq_ids: List[int], + seq_groups: List[Tuple[List[int], SamplingParams]], + seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. prompt_lens: List[int], slot_mapping: torch.Tensor, context_lens: torch.Tensor, - # FIXME: Rename max_context_len: int, block_tables: torch.Tensor, ) -> None: - self.seq_ids = seq_ids + self.seq_groups = seq_groups + self.seq_logprobs = seq_logprobs self.prompt_lens = prompt_lens self.slot_mapping = slot_mapping self.context_lens = context_lens @@ -23,19 +26,20 @@ def __init__( self.block_tables = block_tables self.num_prompts = len(prompt_lens) + self.num_prompt_tokens = sum(prompt_lens) self.num_generation_tokens = context_lens.shape[0] self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: self.max_num_blocks_per_seq = block_tables.shape[1] else: self.max_num_blocks_per_seq = 0 - assert self.num_generation_tokens == block_tables.shape[0] - assert self.num_prompts + self.num_generation_tokens == len(seq_ids) + assert block_tables.shape[0] == self.num_generation_tokens + assert context_lens.shape[0] == self.num_generation_tokens def __repr__(self) -> str: return (f'InputMetadata(' - f'seq_ids={self.seq_ids}, ' f'num_prompts={self.num_prompts}, ' + f'num_prompt_tokens={self.num_prompt_tokens}, ' f'num_generation_tokens={self.num_generation_tokens}, ' f'num_valid_tokens={self.num_valid_tokens}, ' f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 0e7e4d3b2dd09..d26fd8c46a1dc 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -1,5 +1,7 @@ +import random from typing import Union +import numpy as np import torch import torch.nn as nn @@ -30,3 +32,11 @@ def get_model( model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype) return model.eval() raise ValueError(f'Invalid model name: {model_name}') + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 14d38d4073195..5790335890599 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -9,6 +9,7 @@ from cacheflow.models import InputMetadata from cacheflow.models.attention import OPTCacheFlowAttention from cacheflow.models.sample import Sampler +from cacheflow.sequence import SequenceOutputs KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -261,7 +262,7 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, Tuple[int, int]]: + ) -> Dict[int, SequenceOutputs]: hidden_states = self.model( input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler( diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 5c984d39781f1..7bdc6a42771a9 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -4,6 +4,8 @@ import torch.nn as nn from cacheflow.models import InputMetadata +from cacheflow.sampling_params import SamplingParams +from cacheflow.sequence import SequenceOutputs class Sampler(nn.Module): @@ -16,27 +18,266 @@ def forward( embedding: torch.Tensor, hidden_states: torch.Tensor, input_metadata: InputMetadata, - ) -> Dict[int, Tuple[int, int]]: - # Get the hidden states of the last tokens. - start_idx = 0 - last_token_indicies: List[int] = [] - for prompt_len in input_metadata.prompt_lens: - last_token_indicies.append(start_idx + prompt_len - 1) - start_idx += prompt_len - last_token_indicies.extend( - range(start_idx, start_idx + input_metadata.num_generation_tokens)) - hidden_states = hidden_states[last_token_indicies] + ) -> Dict[int, SequenceOutputs]: + # Get the hidden states that we use for sampling. + hidden_states = _prune_hidden_states(hidden_states, input_metadata) # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) + # Apply temperature scaling. + temperatures = _get_temperatures(input_metadata) + assert len(temperatures) == logits.shape[0] + if any(t != 1.0 for t in temperatures): + t = torch.tensor( + temperatures, dtype=logits.dtype, device=logits.device) + # Use in-place division to avoid creating a new tensor. + logits.div_(t.unsqueeze(dim=1)) + + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities (before applying top-p). + logprobs = torch.log(probs) + + # Apply top-p truncation. + top_ps = _get_top_ps(input_metadata) + assert len(top_ps) == probs.shape[0] + if any(p < 1.0 for p in top_ps): + p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device) + probs = _apply_top_p(probs, p) + # Sample the next tokens. - # TODO(woosuk): Implement other sampling methods. - next_token_ids = torch.argmax(logits, dim=-1) + return _sample(probs, logprobs, input_metadata) + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + input_metadata: InputMetadata, +) -> torch.Tensor: + start_idx = 0 + last_token_indicies: List[int] = [] + for prompt_len in input_metadata.prompt_lens: + last_token_indicies.append(start_idx + prompt_len - 1) + start_idx += prompt_len + last_token_indicies.extend( + range(start_idx, start_idx + input_metadata.num_generation_tokens)) + return hidden_states[last_token_indicies] + + +def _get_temperatures( + input_metadata: InputMetadata, +) -> List[float]: + # Collect the temperatures for the logits. + temperatures: List[float] = [] + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + temperature = sampling_params.temperature + if temperature == 0.0: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. + temperature = 1.0 + + if i < input_metadata.num_prompts: + # A prompt input. + temperatures.append(temperature) + else: + # A generation token. + temperatures += [temperature] * len(seq_ids) + return temperatures + + +def _get_top_ps( + input_metadata: InputMetadata, +) -> List[float]: + top_ps: List[float] = [] + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + if i < input_metadata.num_prompts: + # A prompt input. + top_ps.append(sampling_params.top_p) + else: + # A generation token. + top_ps += [sampling_params.top_p] * len(seq_ids) + return top_ps + + +def _apply_top_p( + probs: torch.Tensor, + p: torch.Tensor, +) -> torch.Tensor: + # TODO(woosuk): Optimize. + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + probs = torch.gather( + probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1)) + return probs + + +def _get_topk_logprobs( + logprobs: torch.Tensor, + num_logprobs: int, +) -> Dict[int, float]: + if num_logprobs == 0: + return {} + + topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs) + if num_logprobs == 1: + topk_logprobs = [topk_logprobs.item()] + topk_ids = [topk_ids.item()] + else: + topk_logprobs = topk_logprobs.tolist() + topk_ids = topk_ids.tolist() + + token_to_logprob: Dict[int, float] = {} + for token_id, logprob in zip(topk_ids, topk_logprobs): + token_to_logprob[token_id] = logprob + return token_to_logprob + + +def _sample_from_prompt( + prob: torch.Tensor, + sampling_params: SamplingParams, +) -> List[int]: + if sampling_params.use_beam_search: + # Beam search. + beam_width = sampling_params.n + _, next_token_ids = torch.topk(prob, beam_width) + next_token_ids = next_token_ids.tolist() + elif sampling_params.temperature == 0.0: + # Greedy sampling. + assert sampling_params.n == 1 + next_token_id = torch.argmax(prob) + next_token_ids = [next_token_id.item()] + else: + # Neucleus sampling. + # Sample n tokens for the prompt. + n = sampling_params.n + next_token_ids = torch.multinomial( + prob, num_samples=n, replacement=True) next_token_ids = next_token_ids.tolist() + return next_token_ids + + +def _sample_from_generation_tokens( + seq_ids: List[int], + probs: torch.Tensor, + logprobs: torch.Tensor, + seq_logprobs: List[float], + sampling_params: SamplingParams, +) -> Tuple[List[int], List[int]]: + # NOTE(woosuk): sampling_params.n can be greater than + # len(seq_ids) because some sequences in the group might have + # been already terminated. + if sampling_params.use_beam_search: + # Beam search. + # Add cumulative logprobs for the sequences in the group. + seq_logprobs = torch.tensor( + seq_logprobs, dtype=torch.float, device=logprobs.device) + logprobs = logprobs + seq_logprobs.unsqueeze(dim=1) + + vocab_size = logprobs.size(-1) + beam_width = len(seq_ids) + _, topk_ids = torch.topk(logprobs.flatten(), beam_width) + seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist() + beam_seq_ids = [seq_ids[i] for i in seq_idx] + token_ids = (topk_ids % vocab_size).tolist() + + beam_outputs: Dict[int, Tuple[int, int]] = {} + outstanding_beams: List[Tuple[int, int]] = [] + # If a beam survives, continue with it. + for seq_id, token_id in zip(beam_seq_ids, token_ids): + if seq_id not in beam_outputs: + beam_outputs[seq_id] = (seq_id, token_id) + else: + outstanding_beams.append((seq_id, token_id)) + + # If a beam is discarded, fork another beam. + for seq_id in seq_ids: + if seq_id not in beam_outputs: + beam_outputs[seq_id] = outstanding_beams.pop() + assert not outstanding_beams + + parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids] + next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids] + elif sampling_params.temperature == 0.0: + # Greedy sampling. + assert len(seq_ids) == 1 + next_token_id = torch.argmax(probs, dim=-1) + next_token_ids = [next_token_id.item()] + parent_seq_ids = seq_ids + else: + # Neucleus sampling. + # Sample 1 token for each sequence in the group. + next_token_ids = torch.multinomial( + probs, num_samples=1, replacement=True) + next_token_ids = next_token_ids.squeeze(dim=-1).tolist() + parent_seq_ids = seq_ids + return parent_seq_ids, next_token_ids + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + input_metadata: InputMetadata, +) -> Dict[int, SequenceOutputs]: + seq_outputs: Dict[int, SequenceOutputs] = {} + + # TODO(woosuk): Optimize. + idx = 0 + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + if i < input_metadata.num_prompts: + # Generate the next tokens for a prompt input. + assert len(seq_ids) == sampling_params.n + prob = probs[idx] + logprob = logprobs[idx] + idx += 1 + + # Sample the next tokens. + next_token_ids = _sample_from_prompt(prob, sampling_params) + # Get top-k log probabilities for the next tokens. + next_logprobs = _get_topk_logprobs( + logprob, sampling_params.num_logprobs) + + # Build the output. + for seq_id, next_token_id in zip(seq_ids, next_token_ids): + output_logprobs = next_logprobs.copy() + output_logprobs[next_token_id] = logprob[next_token_id].item() + seq_outputs[seq_id] = SequenceOutputs( + seq_id, seq_id, next_token_id, output_logprobs) + else: + # Generate the next tokens for generation tokens. + prob = probs[idx:idx + len(seq_ids)] + logprob = logprobs[idx:idx + len(seq_ids)] + idx += len(seq_ids) + + # Sample the next tokens. + seq_logprobs = [ + input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids] + parent_seq_ids, next_token_ids = _sample_from_generation_tokens( + seq_ids, prob, logprob, seq_logprobs, sampling_params) + + # Get top-k log probabilities for the next tokens. + next_logprobs: Dict[int, Dict[int, float]] = {} + for i, seq_id in enumerate(seq_ids): + next_logprobs[seq_id] = _get_topk_logprobs( + logprob[i], sampling_params.num_logprobs) + + # Build the output. + for seq_id, parent_seq_id, next_token_id in zip( + seq_ids, parent_seq_ids, next_token_ids): + i = seq_ids.index(parent_seq_id) + output_logprobs = next_logprobs[parent_seq_id].copy() + output_logprobs[next_token_id] = logprob[i, next_token_id].item() + seq_outputs[seq_id] = SequenceOutputs( + seq_id, + parent_seq_id, + next_token_id, + output_logprobs, + ) - # Return the next tokens. - next_tokens: Dict[int, Tuple[int, int]] = {} - for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids): - next_tokens[seq_id] = (seq_id, token_id) - return next_tokens + return seq_outputs diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 241d248a0b602..5f446198bd67c 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -5,27 +5,51 @@ class SamplingParams: def __init__( self, - n: int = 1, - temperature: float = 1.0, - top_p: float = 1.0, - use_beam_search: bool = False, - stop_token_ids: Set[int] = [], - max_num_steps: int = 16, # From OpenAI API. - max_context_len: Optional[int] = None, + n: int, + temperature: float, + top_p: float, + use_beam_search: bool, + stop_token_ids: Set[int], + max_num_steps: int, + num_logprobs: int, + context_window_size: Optional[int], ) -> None: - assert n >= 1 - assert temperature >= 0.0 - assert 0.0 < top_p <= 1.0 + if n < 1: + raise ValueError(f'n must be at least 1, got {n}.') + if temperature < 0.0: + raise ValueError( + f'temperature must be non-negative, got {temperature}.') + if not 0.0 < top_p <= 1.0: + raise ValueError(f'top_p must be in (0, 1], got {top_p}.') + if max_num_steps < 1: + raise ValueError( + f'max_num_steps must be at least 1, got {max_num_steps}.') + if num_logprobs < 0: + raise ValueError( + f'num_logprobs must be non-negative, got {num_logprobs}.') + if context_window_size is not None and context_window_size < 0: + raise ValueError( + 'context_window_size must be non-negative, ' + f'got {context_window_size}.') + if use_beam_search: - assert n > 1 - assert temperature > 0.0 - assert top_p == 1.0 + if n == 1: + raise ValueError( + 'n must be greater than 1 when using beam search.') + if temperature > 0.0: + raise ValueError( + 'temperature must be 0 when using beam search.') + if top_p < 1.0: + raise ValueError( + 'top_p must be 1 when using beam search.') elif temperature == 0.0: - # Zero temperature means greedy decoding. - assert n == 1 - assert top_p == 1.0 - assert max_num_steps >= 1 - assert max_context_len is None or max_context_len >= 0 + # Zero temperature means greedy sampling. + if n > 1: + raise ValueError( + 'n must be 1 when using greedy sampling.') + if top_p < 1.0: + raise ValueError( + 'top_p must be 1 when using greedy sampling.') self.n = n self.temperature = temperature @@ -33,4 +57,15 @@ def __init__( self.use_beam_search = use_beam_search self.stop_token_ids = stop_token_ids self.max_num_steps = max_num_steps - self.max_context_len = max_context_len + self.num_logprobs = num_logprobs + self.context_window_size = context_window_size + + def __repr__(self) -> str: + return (f'SamplingParams(n={self.n}, ' + f'temperature={self.temperature}, ' + f'top_p={self.top_p}, ' + f'use_beam_search={self.use_beam_search}, ' + f'stop_token_ids={self.stop_token_ids}, ' + f'max_num_steps={self.max_num_steps}, ' + f'num_logprobs={self.num_logprobs}, ' + f'context_window_size={self.context_window_size})') diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 364f3b2e69e21..fb9e9daba01ec 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -1,7 +1,9 @@ +import copy import enum -from typing import List, Optional +from typing import Dict, List, Optional from cacheflow.block import LogicalTokenBlock +from cacheflow.sampling_params import SamplingParams class SequenceStatus(enum.Enum): @@ -24,9 +26,11 @@ def __init__( self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the given token ids. - self.append(token_ids) + self.add(token_ids) self.status = SequenceStatus.PENDING + self.output_logprobs: List[Dict[int, float]] = [] + self.cumulative_logprobs = 1.0 def add_block(self) -> None: block = LogicalTokenBlock( @@ -35,7 +39,7 @@ def add_block(self) -> None: ) self.logical_token_blocks.append(block) - def append(self, token_ids: List[int]) -> None: + def add(self, token_ids: List[int]) -> None: while token_ids: if not self.logical_token_blocks: self.add_block() @@ -49,6 +53,12 @@ def append(self, token_ids: List[int]) -> None: last_block.append(token_ids[:num_empty_slots]) token_ids = token_ids[num_empty_slots:] + def append(self, token_id: int, logprobs: Dict[int, float]) -> None: + assert token_id in logprobs + self.add([token_id]) + self.output_logprobs.append(logprobs) + self.cumulative_logprobs += logprobs[token_id] + def get_len(self) -> int: return sum(block.num_tokens for block in self.logical_token_blocks) @@ -58,6 +68,14 @@ def get_token_ids(self) -> List[int]: token_ids.extend(block.get_token_ids()) return token_ids + def get_last_token_id(self) -> int: + return self.logical_token_blocks[-1].get_last_token_id() + + def fork(self, child_seq: 'Sequence') -> 'Sequence': + child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks) + child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) + child_seq.cumulative_logprobs = self.cumulative_logprobs + def __repr__(self) -> str: return (f'Sequence(seq_id={self.seq_id}, ' f'status={self.status.name}, ' @@ -74,11 +92,17 @@ def __init__( self.group_id = group_id self.seqs = seqs - def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + def get_seqs( + self, + status: Optional[SequenceStatus] = None, + ) -> List[Sequence]: if status is None: - return len(self.seqs) + return self.seqs else: - return len([seq for seq in self.seqs if seq.status == status]) + return [seq for seq in self.seqs if seq.status == status] + + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + return len(self.get_seqs(status)) def find(self, seq_id: int) -> Sequence: for seq in self.seqs: @@ -92,3 +116,45 @@ def is_finished(self) -> bool: def __repr__(self) -> str: return (f'SequenceGroup(group_id={self.group_id}, ' f'num_seqs={len(self.seqs)})') + + +class SequenceGroupInputs: + + def __init__( + self, + group_id: int, + is_prompt: bool, + input_tokens: Dict[int, List[int]], # Seq id -> token ids. + context_len: int, + seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. + sampling_params: SamplingParams, + block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers. + ) -> None: + self.group_id = group_id + self.is_prompt = is_prompt + self.input_tokens = input_tokens + self.context_len = context_len + self.seq_logprobs = seq_logprobs + self.sampling_params = sampling_params + self.block_tables = block_tables + + +class SequenceOutputs: + + def __init__( + self, + seq_id: int, + parent_seq_id: int, + output_token: int, + logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i). + ) -> None: + self.seq_id = seq_id + self.parent_seq_id = parent_seq_id + self.output_token = output_token + self.logprobs = logprobs + + def __repr__(self) -> str: + return (f'SequenceOutputs(seq_id={self.seq_id}, ' + f'parent_seq_id={self.parent_seq_id}, ' + f'output_token={self.output_token}), ' + f'logprobs={self.logprobs}') diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py index d8597a3b1cf1e..cf509abb647b5 100644 --- a/cacheflow/worker/cache_engine.py +++ b/cacheflow/worker/cache_engine.py @@ -97,7 +97,7 @@ def allocate_cpu_cache(self) -> List[KVCache]: cpu_cache.append((key_blocks, value_blocks)) return cpu_cache - def _copy_blocks( + def _swap( self, src: List[KVCache], dst: List[KVCache], @@ -108,19 +108,38 @@ def _copy_blocks( src_key_cache, src_value_cache = src[i] dst_key_cache, dst_value_cache = dst[i] # Copy the key blocks. - cache_ops.copy_cache_blocks( + cache_ops.swap_blocks( src_key_cache, dst_key_cache, src_to_dst) # Copy the value blocks. - cache_ops.copy_cache_blocks( + cache_ops.swap_blocks( src_value_cache, dst_value_cache, src_to_dst) event = self.events[i] event.record(stream=self.cache_stream) - def copy(self, src_to_dst: Dict[int, int]) -> None: - self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst) - def swap_in(self, src_to_dst: Dict[int, int]) -> None: - self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst) + self._swap(self.cpu_cache, self.gpu_cache, src_to_dst) def swap_out(self, src_to_dst: Dict[int, int]) -> None: - self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst) + self._swap(self.gpu_cache, self.cpu_cache, src_to_dst) + + def _copy( + self, + src: List[KVCache], + dst: List[KVCache], + src_to_dsts: Dict[int, List[int]], + ) -> None: + with torch.cuda.stream(self.cache_stream): + for i in range(self.num_layers): + src_key_cache, src_value_cache = src[i] + dst_key_cache, dst_value_cache = dst[i] + # Copy the key blocks. + cache_ops.copy_blocks( + src_key_cache, dst_key_cache, src_to_dsts) + # Copy the value blocks. + cache_ops.copy_blocks( + src_value_cache, dst_value_cache, src_to_dsts) + event = self.events[i] + event.record(stream=self.cache_stream) + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts) diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py index f75804f1da13a..31bd03e0c20bf 100644 --- a/cacheflow/worker/controller.py +++ b/cacheflow/worker/controller.py @@ -1,6 +1,7 @@ from typing import Dict, List, Union from cacheflow.master.scheduler import Scheduler +from cacheflow.sequence import SequenceGroupInputs from cacheflow.worker.worker import Worker @@ -14,7 +15,8 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, - dtype: str = 'half', + dtype: str, + seed: int, ) -> None: self.node_id = node_id self.num_workers = num_workers @@ -37,6 +39,7 @@ def __init__( num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, dtype=dtype, + seed=seed, ) self.workers.append(worker) @@ -49,22 +52,16 @@ def set_next( def execute_stage( self, - prompt_tokens: Dict[int, List[int]], - generation_tokens: Dict[int, int], - context_lens: Dict[int, int], - block_tables: Dict[int, List[int]], + input_seq_groups: List[SequenceGroupInputs], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], ) -> None: # FIXME: Support tensor parallelism. assert len(self.workers) == 1 worker = self.workers[0] output = worker.execute_stage( - prompt_tokens, - generation_tokens, - context_lens, - block_tables, + input_seq_groups, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index aabbb7e3f1e72..84d8cdd5390a2 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -1,9 +1,13 @@ -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple import torch from cacheflow.models import get_model +from cacheflow.models import set_seed from cacheflow.models import InputMetadata +from cacheflow.sampling_params import SamplingParams +from cacheflow.sequence import SequenceGroupInputs +from cacheflow.sequence import SequenceOutputs from cacheflow.worker.cache_engine import CacheEngine @@ -18,6 +22,7 @@ def __init__( num_gpu_blocks: int, num_cpu_blocks: int, dtype: str, + seed: int, ) -> None: self.worker_id = worker_id self.gpu_id = gpu_id @@ -33,6 +38,11 @@ def __init__( self.head_size = self.model.config.hidden_size // self.num_heads self.dtype = self.model.dtype + # Set the seed. + # We set the seed after initializing the model to ensure that + # the random state is not affected by the model initialization. + set_seed(seed) + self.cache_engine = CacheEngine( worker_id=worker_id, gpu_id=gpu_id, @@ -49,55 +59,81 @@ def __init__( def prepare_inputs( self, - prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids. - generation_tokens: Dict[int, int], # Seq id -> Input token id. - context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention. - block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers. + input_seq_groups: List[SequenceGroupInputs], ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]: - # TODO(woosuk): Support interactive generation. - # Add the prompt tokens. - prompt_lens: List[int] = [] + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + seq_logprobs: Dict[int, float] = {} + sampling_params: Dict[int, SamplingParams] = {} input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - prompt_seq_ids = sorted(prompt_tokens.keys()) - for seq_id in prompt_seq_ids: - prompt_len = len(prompt_tokens[seq_id]) + # Add prompt tokens. + prompt_lens: List[int] = [] + for input_seq_group in input_seq_groups: + if not input_seq_group.is_prompt: + continue + + seq_ids = list(input_seq_group.input_tokens.keys()) + sampling_params = input_seq_group.sampling_params + seq_groups.append((seq_ids, sampling_params)) + seq_logprobs.update(input_seq_group.seq_logprobs) + + # Use any sequence in the group. + seq_id = seq_ids[0] + + prompt_tokens = input_seq_group.input_tokens[seq_id] + prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) - input_tokens.extend(prompt_tokens[seq_id]) - input_positions.extend(range(len(prompt_tokens[seq_id]))) + input_tokens.extend(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(range(len(prompt_tokens))) - block_table = block_tables[seq_id] + # Compute the slot mapping. + block_table = input_seq_group.block_tables[seq_id] for i in range(prompt_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - # Add the generation tokens. + # Add generation tokens. max_context_len = 0 max_num_blocks_per_seq = 0 + context_lens: List[int] = [] generation_block_tables: List[List[int]] = [] - - generation_seq_ids = sorted(generation_tokens.keys()) - for seq_id in generation_seq_ids: - input_tokens.append(generation_tokens[seq_id]) - position_id = context_lens[seq_id] - 1 - input_positions.append(position_id) - - block_table = block_tables[seq_id] - generation_block_tables.append(block_table) - - max_context_len = max(max_context_len, context_lens[seq_id]) - max_num_blocks_per_seq = max( - max_num_blocks_per_seq, len(block_table)) - - block_number = block_table[position_id // self.block_size] - block_offset = position_id % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) + for input_seq_group in input_seq_groups: + if input_seq_group.is_prompt: + continue + + seq_ids = list(input_seq_group.input_tokens.keys()) + sampling_params = input_seq_group.sampling_params + seq_groups.append((seq_ids, sampling_params)) + seq_logprobs.update(input_seq_group.seq_logprobs) + + for seq_id in seq_ids: + assert len(input_seq_group.input_tokens[seq_id]) == 1 + generation_token = input_seq_group.input_tokens[seq_id][0] + input_tokens.append(generation_token) + + position = input_seq_group.context_len - 1 + input_positions.append(position) + + block_table = input_seq_group.block_tables[seq_id] + generation_block_tables.append(block_table) + + max_context_len = max( + max_context_len, input_seq_group.context_len) + max_num_blocks_per_seq = max( + max_num_blocks_per_seq, len(block_table)) + context_lens.append(input_seq_group.context_len) + + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) # Optimization: Pad the input length to be a multiple of 8. # This is required for utilizing the Tensor Cores in NVIDIA GPUs. @@ -112,8 +148,7 @@ def prepare_inputs( slot_mapping_tensor = torch.tensor( slot_mapping, dtype=torch.int, device=self.device) context_lens_tensor = torch.tensor( - [context_lens[seq_id] for seq_id in generation_seq_ids], - dtype=torch.int, device=self.device) + context_lens, dtype=torch.int, device=self.device) padded_block_tables = [ _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in generation_block_tables] @@ -121,7 +156,8 @@ def prepare_inputs( padded_block_tables, dtype=torch.int, device=self.device) input_metadata = InputMetadata( - seq_ids=prompt_seq_ids + generation_seq_ids, + seq_groups=seq_groups, + seq_logprobs=seq_logprobs, prompt_lens=prompt_lens, slot_mapping=slot_mapping_tensor, context_lens=context_lens_tensor, @@ -133,14 +169,11 @@ def prepare_inputs( @torch.inference_mode() def execute_stage( self, - prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids. - generation_tokens: Dict[int, int], # Seq id -> Input token id. - context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention. - block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers. + input_seq_groups: List[SequenceGroupInputs], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, int], - ) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]: + blocks_to_copy: Dict[int, List[int]], + ) -> Dict[int, SequenceOutputs]: # Issue cache operations. command_issued = False if blocks_to_swap_in: @@ -160,7 +193,7 @@ def execute_stage( # Prepare input tensors. input_tokens, input_positions, input_metadata = self.prepare_inputs( - prompt_tokens, generation_tokens, context_lens, block_tables) + input_seq_groups) # Execute the model. output = self.model( diff --git a/csrc/cache.cpp b/csrc/cache.cpp index e20d1c3f114a5..fcf8b69fe8608 100644 --- a/csrc/cache.cpp +++ b/csrc/cache.cpp @@ -1,10 +1,18 @@ #include -void copy_blocks( +#include +#include + +void swap_blocks( torch::Tensor& src, torch::Tensor& dst, const std::map& block_mapping); +void copy_blocks( + torch::Tensor& src, + torch::Tensor& dst, + const std::map>& block_mapping); + void reshape_and_cache( torch::Tensor& key, torch::Tensor& value, @@ -14,7 +22,11 @@ void reshape_and_cache( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( - "copy_cache_blocks", + "swap_blocks", + &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + m.def( + "copy_blocks", ©_blocks, "Copy the cache blocks from src to dst"); m.def( diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 2e9ca8f0df7ba..0366b922c28db 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -5,8 +5,9 @@ #include #include #include +#include -void copy_blocks( +void swap_blocks( torch::Tensor& src, torch::Tensor& dst, const std::map& block_mapping) { @@ -43,6 +44,35 @@ void copy_blocks( } } +void copy_blocks( + torch::Tensor& src, + torch::Tensor& dst, + const std::map>& block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + assert(src_device.is_cuda() && dst_device.is_cuda()); + cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice; + + void *src_ptr = src.data_ptr(); + void *dst_ptr = dst.data_ptr(); + + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + for (const auto& pair : block_mapping) { + int64_t src_block_number = pair.first; + for (int64_t dst_block_number : pair.second) { + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + cudaMemcpyAsync( + dst_ptr + dst_offset, + src_ptr + src_offset, + block_size_in_bytes, + memcpy_type, + stream); + } + } +} + template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] diff --git a/server.py b/server.py index 04e2f6d726693..c873caf15d5d6 100644 --- a/server.py +++ b/server.py @@ -15,6 +15,8 @@ parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') +# TODO(woosuk): Support fine-grained seeds (e.g., seed per request). +parser.add_argument('--seed', type=int, default=0, help='random seed') args = parser.parse_args() @@ -30,6 +32,7 @@ def main(): num_gpu_blocks=args.num_gpu_blocks, num_cpu_blocks=args.num_cpu_blocks, dtype=args.dtype, + seed=args.seed, ) controllers.append(controller) @@ -52,18 +55,18 @@ def main(): controllers[i].set_next(controllers[i + 1]) controllers[-1].set_next(scheduler) + # Test the following inputs. test_inputs = [ - 'Ion Stoica is a', - 'UC Berkeley is', - 'The future of cloud computing is', + ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}), + ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}), + ('The future of cloud computing is', {}), # Use default parameters. ] - - # FIXME while True: if test_inputs: - frontend.query(test_inputs.pop()) + text, sampling_params = test_inputs.pop(0) + frontend.query(text, **sampling_params) scheduler.step() - if not scheduler.pending and not scheduler.running: + if not (scheduler.pending or scheduler.running or test_inputs): break