Support beam search & parallel generation (#7)
WoosukKwon authored Mar 10, 2023
1 parent 04e5acc commit 1a7eb7d
Showing 16 changed files with 662 additions and 163 deletions.
4 changes: 4 additions & 0 deletions cacheflow/block.py
@@ -35,6 +35,10 @@ def append(self, token_ids: List[int]) -> None:
def get_token_ids(self) -> List[int]:
return self.token_ids[:self.num_tokens]

def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]


class PhysicalTokenBlock:

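For context, a hedged sketch of why the new accessor helps: at decode time each running sequence feeds only its newest token to the model, so the scheduler can ask the last logical block for that token directly instead of slicing out the full token list. The simplified block below mirrors the snippet above; block-size handling is omitted and the layout is assumed.

from typing import List

class LogicalTokenBlock:
    """Simplified stand-in for the block in cacheflow/block.py (assumed layout)."""

    def __init__(self) -> None:
        self.token_ids: List[int] = []
        self.num_tokens = 0

    def append(self, token_ids: List[int]) -> None:
        self.token_ids.extend(token_ids)
        self.num_tokens += len(token_ids)

    def get_last_token_id(self) -> int:
        assert self.num_tokens > 0
        return self.token_ids[self.num_tokens - 1]

# Decode step: only the newest token is needed as model input.
block = LogicalTokenBlock()
block.append([101, 2009, 2003])
next_input = [block.get_last_token_id()]  # [2003], not the whole prefix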
33 changes: 28 additions & 5 deletions cacheflow/master/frontend.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

from transformers import AutoTokenizer

@@ -25,12 +25,35 @@ def __init__(
def query(
self,
prompt: str,
sampling_params: Optional[SamplingParams] = None,
n: int = 1,
temperature: float = 1.0,
top_p: float = 1.0,
use_beam_search: bool = False,
stop_token_ids: Set[int] = set(),
max_num_steps: int = 16, # From OpenAI API.
num_logprobs: int = 0,
context_window_size: Optional[int] = None,
) -> None:
if sampling_params is None:
sampling_params = SamplingParams()
token_ids: List[int] = self.tokenizer.encode(prompt)
# Stop when we see an EOS token.
stop_token_ids.add(self.tokenizer.eos_token_id)
sampling_params = SamplingParams(
n=n,
temperature=temperature,
top_p=top_p,
use_beam_search=use_beam_search,
stop_token_ids=stop_token_ids,
max_num_steps=max_num_steps,
num_logprobs=num_logprobs,
context_window_size=context_window_size,
)
token_ids = self.tokenizer.encode(prompt)
self._add_query(token_ids, sampling_params)

def _add_query(
self,
token_ids: List[int],
sampling_params: SamplingParams,
) -> None:
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
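A hedged usage sketch of the widened query() signature; the Frontend constructor arguments are not part of this diff and are assumed here, and parameter combinations may still be validated inside SamplingParams.

from cacheflow.master.frontend import Frontend

# Constructor arguments below are assumptions for illustration.
frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# Parallel generation: sample 4 independent continuations of one prompt.
frontend.query('The future of AI is', n=4, temperature=0.8, top_p=0.95)

# Beam search over 4 beams; EOS is added to stop_token_ids automatically.
frontend.query('The future of AI is', n=4, use_beam_search=True)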
91 changes: 55 additions & 36 deletions cacheflow/master/scheduler.py
@@ -1,10 +1,12 @@
from typing import Dict, List, Tuple
from typing import Dict, List

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus

_MAX_NUM_BATCHED_TOKENS = 2048
@@ -66,15 +68,18 @@ def _allocate(self, seq_group: SequenceGroup) -> None:
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
blocks_to_copy[src_block] = dst_block
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]

def _swap_in(
self,
@@ -83,9 +88,8 @@ def _swap_in(
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.SWAPPED:
seq.status = SequenceStatus.RUNNING
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)

def _swap_out(
@@ -96,16 +100,15 @@ def _swap_out(
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.RUNNING:
seq.status = SequenceStatus.SWAPPED
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)

def step(self) -> None:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}

# 1. Reserve new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@ def step(self) -> None:
# All swapped sequences are swapped in.
self.swapped.clear()

# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out

num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
@@ -152,7 +159,6 @@ def step(self) -> None:
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
if not self.swapped:
# FIXME(woosuk): Acquire a lock to protect pending.
self._fetch_inputs()
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,73 +174,86 @@ def step(self) -> None:
else:
self.pending.clear()

# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out

# 4. Create input data structures.
prompt_tokens: Dict[int, List[int]] = {}
generation_tokens: Dict[int, int] = {}
context_lens: Dict[int, int] = {}
block_tables: Dict[int, List[int]] = {}
input_seq_groups: List[SequenceGroupInputs] = []
for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]

# NOTE(woosuk): We assume that the number of steps is 0
# for the prompt sequences.
is_prompt = num_steps == 0
for seq in seq_group.seqs:
if seq.status != SequenceStatus.RUNNING:
continue

input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
prompt_tokens[seq_id] = seq.get_token_ids()
input_tokens[seq_id] = seq.get_token_ids()
else:
generation_tokens[seq_id] = seq.get_token_ids()[-1]
context_lens[seq_id] = seq.get_len()
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
# NOTE(woosuk): Sequences in the same group have the same
# sequence length
seq_len = seq.get_len()

input_seq_group = SequenceGroupInputs(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)

# 5. Execute the first stage of the pipeline.
self.controllers[0].execute_stage(
prompt_tokens,
generation_tokens,
context_lens,
block_tables,
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)

def post_step(
self,
next_tokens: Dict[int, Tuple[int, int]],
seq_outputs: Dict[int, SequenceOutputs],
) -> None:
# Update the running sequences and free blocks.
for seq_group in self.running:
group_id = seq_group.group_id
self.num_steps[group_id] += 1
stop_token_ids = self.sampling_params[group_id].stop_token_ids

# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

parent_seq_id, next_token = next_tokens[seq.seq_id]
if seq.seq_id != parent_seq_id:
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(parent_seq_id)
seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)

# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

# Append a new token to the sequence.
seq.append([next_token])
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)

# Check if the sequence has generated a stop token.
if next_token in stop_token_ids:
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue

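A hedged sketch of the copy-on-write bookkeeping behind the blocks_to_copy change: under beam search several child sequences can fork from the same parent within one step, so a single source block may need copies into multiple destination blocks, which is why the mapping becomes Dict[int, List[int]]. The helper below is illustrative only; a defaultdict is equivalent to the explicit if/else in _append.

from collections import defaultdict
from typing import Dict, List

blocks_to_copy: Dict[int, List[int]] = defaultdict(list)

def record_copy(src_block: int, dst_block: int) -> None:
    # Two beams forked from the same parent may both copy the same source block.
    blocks_to_copy[src_block].append(dst_block)

record_copy(src_block=7, dst_block=12)  # first fork touching block 7
record_copy(src_block=7, dst_block=13)  # second fork touching block 7
assert dict(blocks_to_copy) == {7: [12, 13]}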
4 changes: 3 additions & 1 deletion cacheflow/models/__init__.py
@@ -1,8 +1,10 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_model
from cacheflow.models.model_utils import set_seed


__all__ = [
'get_model',
'InputMetadata',
'get_model',
'set_seed'
]
18 changes: 11 additions & 7 deletions cacheflow/models/input_metadata.py
@@ -1,41 +1,45 @@
from typing import List
from typing import List, Dict, Tuple

import torch

from cacheflow.sampling_params import SamplingParams


class InputMetadata:

def __init__(
self,
seq_ids: List[int],
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
# FIXME: Rename
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
self.seq_ids = seq_ids
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1]
else:
self.max_num_blocks_per_seq = 0
assert self.num_generation_tokens == block_tables.shape[0]
assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens

def __repr__(self) -> str:
return (f'InputMetadata('
f'seq_ids={self.seq_ids}, '
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
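For orientation, a hedged sketch of the shape bookkeeping the revised constructor asserts, with toy numbers; SamplingParams and seq_groups construction is omitted since their full definitions are not shown in this diff.

import torch

# Toy step: one prompt of length 5 plus two sequences in decode.
prompt_lens = [5]
context_lens = torch.tensor([9, 9])            # one entry per decoding sequence
block_tables = torch.tensor([[0, 1], [0, 2]])  # one row per decoding sequence
slot_mapping = torch.arange(5 + 2)             # one KV-cache slot per input token

num_prompts = len(prompt_lens)                 # 1
num_prompt_tokens = sum(prompt_lens)           # 5
num_generation_tokens = context_lens.shape[0]  # 2
num_valid_tokens = slot_mapping.shape[0]       # 7
max_num_blocks_per_seq = block_tables.shape[1] if block_tables.numel() > 0 else 0

# The new assertions reduce to these checks.
assert block_tables.shape[0] == num_generation_tokens
assert context_lens.shape[0] == num_generation_tokens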
10 changes: 10 additions & 0 deletions cacheflow/models/model_utils.py
@@ -1,5 +1,7 @@
import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn

@@ -30,3 +32,11 @@ def get_model(
model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
return model.eval()
raise ValueError(f'Invalid model name: {model_name}')


def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
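The new helper seeds Python, NumPy, and Torch (plus all CUDA devices when available) in one call, so repeated runs sample identically. A quick sketch:

import torch
from cacheflow.models import set_seed  # exported via cacheflow/models/__init__.py

set_seed(0)
a = torch.randn(3)
set_seed(0)
b = torch.randn(3)
assert torch.equal(a, b)  # identical draws after reseeding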
3 changes: 2 additions & 1 deletion cacheflow/models/opt.py
@@ -9,6 +9,7 @@
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -261,7 +262,7 @@ def forward(
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, Tuple[int, int]]:
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
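A hedged sketch of consuming the new return type, mirroring how post_step reads it above; the stand-in dataclass only carries the fields used there, while the real SequenceOutputs lives in cacheflow/sequence.py and may hold more.

from dataclasses import dataclass
from typing import Dict

@dataclass
class SequenceOutputs:  # illustrative stand-in, not the real class
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, float]  # token id -> log-probability (assumed layout)

# What model.forward() now returns: one entry per running sequence id.
seq_outputs: Dict[int, SequenceOutputs] = {
    3: SequenceOutputs(parent_seq_id=3, output_token=42, logprobs={42: -0.1}),
    4: SequenceOutputs(parent_seq_id=3, output_token=17, logprobs={17: -1.3}),
}
for seq_id, out in seq_outputs.items():
    forked = out.parent_seq_id != seq_id  # True for seq 4: a beam-search fork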