Merged PR 1865: Critical bug fixes related to sampling (#18)
- Fix multi-category sampling
- Fix sampling in mixed batches, caused by incorrect ordering of requests within a batch
- Fix Orca/FasterTransformer decode attention with FlashInfer

Co-authored-by: Amey Agrawal <t-amagrawal@microsoft.com>
AgrawalAmey and Amey Agrawal authored Jun 23, 2024
Parent: 8245cba · Commit: 50e59c5
Showing 11 changed files with 158 additions and 81 deletions.
examples/offline_inference.py (5 additions, 6 deletions)

@@ -2,8 +2,7 @@
 from tqdm import tqdm
 from typing import List

-from sarathi.types import AttentionBackend
-from sarathi.config import ModelConfig, ParallelConfig, SarathiSchedulerConfig, MetricsConfig, SystemConfig, ReplicaConfig, WorkerConfig
+from sarathi.config import ModelConfig, ParallelConfig, SarathiSchedulerConfig, MetricsConfig, SystemConfig, ReplicaConfig
 from sarathi import LLMEngine, SamplingParams, RequestOutput

@@ -31,17 +30,17 @@
 )

 model_config = ModelConfig(
-    model="meta-llama/Llama-2-7b-hf",
+    model="meta-llama/Meta-Llama-3-8B-Instruct",
 )

 parallel_config = ParallelConfig(
-    tensor_parallel_size=2,
-    pipeline_parallel_size=2,
+    tensor_parallel_size=1,
+    pipeline_parallel_size=1,
 )

 scheduler_config = SarathiSchedulerConfig(
     chunk_size=100,
-    max_num_seqs=4,
+    max_num_seqs=10,
 )

 metrics_config = MetricsConfig(
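For orientation, here is a minimal sketch of the configuration objects this example now builds. It uses only the classes and parameters visible in the diff above; the wiring into SystemConfig/ReplicaConfig and LLMEngine is left out because those signatures are not shown in this diff.

    from sarathi import SamplingParams
    from sarathi.config import ModelConfig, ParallelConfig, SarathiSchedulerConfig

    # Single-GPU Llama-3-8B-Instruct, mirroring the updated example defaults.
    model_config = ModelConfig(model="meta-llama/Meta-Llama-3-8B-Instruct")
    parallel_config = ParallelConfig(tensor_parallel_size=1, pipeline_parallel_size=1)

    # Sarathi chunked-prefill scheduling: 100-token chunks, up to 10 concurrent sequences.
    scheduler_config = SarathiSchedulerConfig(chunk_size=100, max_num_seqs=10)

    # Per-request sampling settings; the constructor parameters appear in
    # sampling_params.py later in this diff.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)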
sarathi/config/config.py (3 additions, 3 deletions)

@@ -44,7 +44,7 @@ class ModelConfig:
         output). If None, will be derived from the model.
     """

-    model: str = "meta-llama/Meta-Llama-3-8B"
+    model: str = "meta-llama/Meta-Llama-3-8B-Instruct"
     trust_remote_code: bool = True
     download_dir: Optional[str] = None
     load_format: str = "auto"

@@ -167,7 +167,7 @@ class ParallelConfig:
     """

     pipeline_parallel_size: int = 1
-    tensor_parallel_size: int = 8
+    tensor_parallel_size: int = 1

     def __post_init__(self):
         self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size

@@ -243,7 +243,7 @@ def get_type():

 @dataclass
 class SarathiSchedulerConfig(BaseSchedulerConfig):
-    chunk_size: int = 1024
+    chunk_size: int = 512
     enable_dynamic_chunking_schedule: bool = False
     low_chunk_size: Optional[int] = None
     high_chunk_size: Optional[int] = None
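A quick check of why the tensor_parallel_size default matters: __post_init__ derives world_size from the two parallelism degrees, so the previous default of 8 implied an eight-GPU deployment for any config that did not override it. The snippet below assumes, as the diff suggests, that both fields have defaults and the dataclass has no other required fields.

    from sarathi.config import ParallelConfig

    # With the new defaults, a bare ParallelConfig targets a single worker.
    cfg = ParallelConfig()
    assert cfg.world_size == 1  # pipeline_parallel_size (1) * tensor_parallel_size (1)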
sarathi/core/datatypes/sampling_params.py (1 addition, 1 deletion)

@@ -36,7 +36,7 @@ def __init__(
         top_k: int = -1,
         stop: Union[None, str, List[str]] = None,
         ignore_eos: bool = False,
-        max_tokens: int = 16,
+        max_tokens: int = 2048,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
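The practical effect of the new default: requests that omit max_tokens are no longer silently capped at 16 generated tokens. The check below assumes the value is stored on the object as max_tokens and that temperature/top_p are constructor parameters, as the assignments shown above suggest.

    from sarathi import SamplingParams

    params = SamplingParams(temperature=0.7, top_p=0.9)
    # Previously this defaulted to 16, truncating most chat-style completions.
    assert params.max_tokens == 2048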
sarathi/core/datatypes/scheduler_output.py (3 additions, 1 deletion)

@@ -15,7 +15,9 @@ def __init__(
         self.id = id
         self.ignored_seq_ids = ignored_seq_ids
         self.preempted_seq_ids = preempted_seq_ids
-        self.scheduled_seq_metadata_list = scheduled_seq_metadata_list
+        self.scheduled_seq_metadata_list = sorted(
+            scheduled_seq_metadata_list, key=lambda x: not x.is_prompt
+        )
         self.prompt_chunk_lens = [
             metadata.num_prompt_tokens for metadata in scheduled_seq_metadata_list
         ]
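This sort is the core of the mixed-batch fix: not x.is_prompt evaluates to False for prefill (prompt) entries and True for decode entries, and False sorts first, so prefills are placed ahead of decodes in the order the sampler expects. A self-contained illustration with a stand-in metadata class:

    from dataclasses import dataclass

    @dataclass
    class FakeSeqMetadata:  # stand-in for the real scheduled-sequence metadata
        seq_id: int
        is_prompt: bool

    batch = [
        FakeSeqMetadata(seq_id=3, is_prompt=False),  # decode
        FakeSeqMetadata(seq_id=7, is_prompt=True),   # prefill chunk
        FakeSeqMetadata(seq_id=1, is_prompt=False),  # decode
    ]

    # sorted() is stable, so order within the prefill and decode groups is kept;
    # only the groups are reordered, with prefills first.
    ordered = sorted(batch, key=lambda x: not x.is_prompt)
    assert [m.seq_id for m in ordered] == [7, 3, 1]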
sarathi/core/scheduler/orca_scheduler.py (2 additions, 0 deletions)

@@ -32,6 +32,8 @@ def _schedule(self) -> SchedulerOutputs:

         now = time.monotonic()

+        self.running = self.policy.sort_by_priority(now, self.running)
+
         for seq in self.running:
             if not seq.is_paused():
                 continue
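Before this change the Orca scheduler iterated self.running in whatever order the sequences happened to be in; the added call re-sorts the running set by the scheduling policy's priority before the batch is built. A hedged sketch of what an FCFS-style sort_by_priority might look like (the arrival_time attribute is an assumption, not taken from this diff):

    from typing import List

    def sort_by_priority(now: float, running: List) -> List:
        # FCFS: earlier-arriving sequences are scheduled first.  The real code
        # delegates this to self.policy.sort_by_priority(now, self.running).
        return sorted(running, key=lambda seq: seq.arrival_time)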
sarathi/core/scheduler/vllm_scheduler.py (2 additions, 2 deletions)

@@ -23,10 +23,10 @@ def __init__(
     ) -> None:
         super().__init__(model_config, scheduler_config, cache_config)

-        self.max_batched_tokens = self.scheduler_config.get_max_num_batched_tokens(
+        self.max_num_batched_tokens = self.scheduler_config.get_max_num_batched_tokens(
             self.model_config.max_model_len
         )
-        self.prompt_limit = self.max_batched_tokens
+        self.prompt_limit = self.max_num_batched_tokens

     def get_block_space_manager_class(self):
         return VLLMBlockSpaceManager
sarathi/core/sequence_manager/base_sequence_manager.py (2 additions, 0 deletions)

@@ -70,6 +70,7 @@ def on_schedule(
             self._preempt_seq(seq_id)

         seq_metadata_list: List[SequenceMetadata] = []
+
         for seq_sched_metadata in scheduler_outputs.scheduled_seq_metadata_list:
             self._on_seq_scheduled(seq_sched_metadata)
             seq = self.seq_map[seq_sched_metadata.seq_id]

@@ -116,6 +117,7 @@ def on_step_completed(
         for scheduled_seq_metadata, sampler_output in zip(
             scheduler_outputs.scheduled_seq_metadata_list, sampler_outputs
         ):
+            assert scheduled_seq_metadata.seq_id == sampler_output.seq_id
             seq = self.seq_map[scheduled_seq_metadata.seq_id]
             if seq.is_waiting():
                 # seq is preempted
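The added assert makes the pairing explicit: sampler_outputs must be index-aligned with scheduler_outputs.scheduled_seq_metadata_list, so a regression in batch ordering now fails immediately instead of silently appending tokens to the wrong sequence. A minimal illustration with stand-in objects:

    from dataclasses import dataclass

    @dataclass
    class Meta:        # stand-in for scheduled sequence metadata
        seq_id: int

    @dataclass
    class SamplerOut:  # stand-in for one sequence's sampler output
        seq_id: int
        token_id: int

    scheduled = [Meta(seq_id=7), Meta(seq_id=3)]
    outputs = [SamplerOut(seq_id=7, token_id=101), SamplerOut(seq_id=3, token_id=205)]

    for meta, out in zip(scheduled, outputs):
        assert meta.seq_id == out.seq_id  # mirrors the added invariant check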
sarathi/engine/async_llm_engine.py (4 additions, 0 deletions)

@@ -91,6 +91,10 @@ def process_request_output(
         """Process a request output from the engine."""
         request_id = request_output.seq_id

+        if request_id not in self._request_streams:
+            # aborted request
+            return
+
         self._request_streams[request_id].put(request_output)
         if request_output.finished:
             if verbose:
sarathi/entrypoints/openai/protocol.py (1 addition, 1 deletion)

@@ -133,7 +133,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
     model: str
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = 2048
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
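On the OpenAI-compatible endpoint this means a chat request that omits max_tokens can now generate up to 2048 completion tokens instead of 16. A hedged example request; the server URL, the /v1/chat/completions route, and the response shape are assumptions based on the OpenAI API this protocol mirrors, not part of this diff.

    import requests

    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [
            {"role": "user", "content": "Explain chunked prefill in two sentences."}
        ],
        # max_tokens omitted: the server-side default is now 2048, not 16.
    }
    response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    print(response.json()["choices"][0]["message"]["content"])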