Merged PR 1857: Add OpenAI Server (#16)
- Ports OpenAI server from vLLM.
- Adds AsyncLLMEngine.
AgrawalAmey authored Jun 20, 2024
1 parent a1fec20 commit fd3a31e
Showing 40 changed files with 2,231 additions and 206 deletions.
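
Since this commit ports the OpenAI-compatible server from vLLM, a typical way to exercise it is with the standard `openai` Python client pointed at the local endpoint. The server code itself is not shown in this excerpt, so the host, port, route, and model name below are assumptions based on common OpenAI-compatible deployments, not details confirmed by the diff.

```python
# Hedged sketch: querying an OpenAI-compatible server such as the one this
# commit adds. The base_url, port, and model name are assumptions, not values
# taken from this diff.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",  # assumed default host/port
    api_key="none",                       # assumed: no auth needed locally
)

completion = client.completions.create(
    model="meta-llama/Llama-2-7b-hf",  # placeholder model name
    prompt="Explain chunked prefills in one sentence.",
    max_tokens=64,
)
print(completion.choices[0].text)
```

The server-side dependencies this requires (`fastapi`, `uvicorn`) appear in the requirements.txt change below.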
1 change: 0 additions & 1 deletion examples/offline_inference.py
@@ -3,7 +3,6 @@
from typing import List

from sarathi.config import ModelConfig, ParallelConfig, SarathiSchedulerConfig, MetricsConfig, SystemConfig, ReplicaConfig
from sarathi.types import SchedulerType
from sarathi import LLMEngine, SamplingParams, RequestOutput


2 changes: 2 additions & 0 deletions requirements.txt
@@ -20,3 +20,5 @@ tiktoken
grpcio
tqdm
vllm-flash-attn
uvicorn
fastapi
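
The new `uvicorn` and `fastapi` dependencies supply the HTTP layer for the OpenAI server. As a rough illustration only (generic FastAPI/uvicorn usage, not the server code added by this PR), a minimal app served this way looks like:

```python
# Generic FastAPI + uvicorn skeleton; the route and payload shape are
# illustrative and not taken from this PR's server implementation.
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI()


class CompletionRequest(BaseModel):
    prompt: str
    max_tokens: int = 64


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    # A real server would forward the request to the async engine here.
    return {"choices": [{"text": f"echo: {request.prompt[:32]}"}]}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```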
3 changes: 2 additions & 1 deletion sarathi/benchmark/benchmark_runner.py
@@ -8,10 +8,11 @@

import wandb
from sarathi import LLMEngine, SamplingParams
from sarathi.benchmark.config import BenchmarkConfig
from sarathi.benchmark.entities import Request
from sarathi.benchmark.request_generator import RequestGeneratorRegistry
from sarathi.benchmark.utils.random import set_seeds
from sarathi.config import BenchmarkConfig, ReplicaConfig
from sarathi.config import ReplicaConfig
from sarathi.metrics.metrics_store import MetricsStore
from sarathi.types import ReplicaResourceMapping, ResourceMapping
from sarathi.utils import get_ip
180 changes: 180 additions & 0 deletions sarathi/benchmark/config.py
@@ -0,0 +1,180 @@
import datetime
from dataclasses import dataclass, field
from typing import Optional

from sarathi.config import (
BaseEndpointConfig,
BaseSchedulerConfig,
CacheConfig,
MetricsConfig,
ModelConfig,
ParallelConfig,
ReplicaConfig,
SarathiSchedulerConfig,
SystemConfig,
WorkerConfig,
)
from sarathi.config.base_poly_config import BasePolyConfig
from sarathi.config.flat_dataclass import create_flat_dataclass
from sarathi.logger import init_logger
from sarathi.types import (
ReplicaResourceMapping,
RequestGeneratorType,
RequestIntervalGeneratorType,
RequestLengthGeneratorType,
)

logger = init_logger(__name__)


@dataclass
class BaseRequestIntervalGeneratorConfig(BasePolyConfig):
seed: int = 42


@dataclass
class BaseRequestLengthGeneratorConfig(BasePolyConfig):
seed: int = 42


@dataclass
class TraceRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig):
trace_file: str = (
"data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv"
)
start_time: str = "1970-01-04 12:00:00"
end_time: str = "1970-01-04 15:00:00"
time_scale_factor: float = 0.3

@staticmethod
def get_type():
return RequestIntervalGeneratorType.TRACE


@dataclass
class PoissonRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig):
qps: float = 1.0

@staticmethod
def get_type():
return RequestIntervalGeneratorType.POISSON


@dataclass
class GammaRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig):
qps: float = 1.0
cv: float = 0.5

@staticmethod
def get_type():
return RequestIntervalGeneratorType.GAMMA


@dataclass
class StaticRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig):
@staticmethod
def get_type():
return RequestIntervalGeneratorType.STATIC


@dataclass
class TraceRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig):
trace_file: str = (
"data/processed_traces/sharegpt_8k_filtered_stats_llama2_tokenizer.csv"
)
prefill_scale_factor: float = 1
decode_scale_factor: float = 1
max_tokens: int = 4096

@staticmethod
def get_type():
return RequestLengthGeneratorType.TRACE


@dataclass
class ZipfRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig):
theta: float = 0.6
scramble: bool = False
min_tokens: int = 1024
max_tokens: int = 4096
prefill_to_decode_ratio: float = 20.0

@staticmethod
def get_type():
return RequestLengthGeneratorType.ZIPF


@dataclass
class UniformRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig):
min_tokens: int = 1024
max_tokens: int = 4096
prefill_to_decode_ratio: float = 20.0

@staticmethod
def get_type():
return RequestLengthGeneratorType.UNIFORM


@dataclass
class FixedRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig):
prefill_tokens: int = 4096
decode_tokens: int = 512

@staticmethod
def get_type():
return RequestLengthGeneratorType.FIXED


@dataclass
class BaseRequestGeneratorConfig(BasePolyConfig):
seed: int = 42


@dataclass
class SyntheticRequestGeneratorConfig(BaseRequestGeneratorConfig):
length_generator_config: BaseRequestLengthGeneratorConfig = field(
default_factory=FixedRequestLengthGeneratorConfig
)
interval_generator_config: BaseRequestIntervalGeneratorConfig = field(
default_factory=PoissonRequestIntervalGeneratorConfig
)
num_requests: int = 64
duration: float = None

@staticmethod
def get_type():
return RequestGeneratorType.SYNTHETIC


@dataclass
class TraceRequestGeneratorConfig(BaseRequestGeneratorConfig):
trace_file: str = "data/processed_traces/sydney_enterprise.csv"
date: str = "2023-08-21"
prefill_scale_factor: float = 0.3
decode_scale_factor: float = 1
time_scale_factor: float = 0.04
max_tokens: int = 4096

@staticmethod
def get_type():
return RequestGeneratorType.TRACE


@dataclass
class BenchmarkConfig(BaseEndpointConfig):
seed: int = 42
output_dir: str = "benchmark_output"
write_json_trace: bool = True
enable_profiling: bool = False
time_limit: Optional[int] = None
num_replicas: int = 1
replica_resource_mapping: Optional[ReplicaResourceMapping] = None
request_generator_config: BaseRequestGeneratorConfig = field(
default_factory=SyntheticRequestGeneratorConfig
)

def __post_init__(self):
super().__post_init__()

if not self.time_limit:
self.time_limit = float("inf")
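
The new benchmark configs compose as nested dataclasses. A minimal construction sketch, assuming the `BaseEndpointConfig` fields not shown in this diff all have usable defaults:

```python
# Sketch of composing the new benchmark config classes. Only fields visible in
# this diff are set explicitly; BaseEndpointConfig's own fields (not shown
# here) are assumed to have defaults.
from sarathi.benchmark.config import (
    BenchmarkConfig,
    FixedRequestLengthGeneratorConfig,
    PoissonRequestIntervalGeneratorConfig,
    SyntheticRequestGeneratorConfig,
)

config = BenchmarkConfig(
    output_dir="benchmark_output",
    time_limit=600,  # seconds; __post_init__ turns a falsy value into float("inf")
    request_generator_config=SyntheticRequestGeneratorConfig(
        num_requests=128,
        length_generator_config=FixedRequestLengthGeneratorConfig(
            prefill_tokens=2048,
            decode_tokens=256,
        ),
        interval_generator_config=PoissonRequestIntervalGeneratorConfig(qps=2.0),
    ),
)
```

With `BenchmarkConfig` now defined here, the benchmark entry point and request generators import it from `sarathi.benchmark.config` instead of `sarathi.config`, which is what the import changes in the files below reflect.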
2 changes: 1 addition & 1 deletion sarathi/benchmark/main.py
@@ -4,9 +4,9 @@
import yaml

from sarathi.benchmark.benchmark_runner import BenchmarkRunnerLauncher
from sarathi.benchmark.config import BenchmarkConfig
from sarathi.benchmark.constants import LOGGER_FORMAT, LOGGER_TIME_FORMAT
from sarathi.benchmark.utils.random import set_seeds
from sarathi.config import BenchmarkConfig


def main() -> None:
@@ -2,8 +2,8 @@
from abc import ABC, abstractmethod
from typing import List

from sarathi.benchmark.config import BaseRequestGeneratorConfig
from sarathi.benchmark.entities import Request
from sarathi.config import BaseRequestGeneratorConfig


class BaseRequestGenerator(ABC):
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod

from sarathi.config import BaseRequestIntervalGeneratorConfig
from sarathi.benchmark.config import BaseRequestIntervalGeneratorConfig


class BaseRequestIntervalGenerator(ABC):
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Tuple

from sarathi.config import BaseRequestLengthGeneratorConfig
from sarathi.benchmark.config import BaseRequestLengthGeneratorConfig


class BaseRequestLengthGenerator(ABC):
@@ -1,9 +1,9 @@
from scipy.stats import gamma

from sarathi.benchmark.config import GammaRequestIntervalGeneratorConfig
from sarathi.benchmark.request_generator.base_request_interval_generator import (
BaseRequestIntervalGenerator,
)
from sarathi.config import GammaRequestIntervalGeneratorConfig


class GammaRequestIntervalGenerator(BaseRequestIntervalGenerator):
@@ -1,10 +1,10 @@
import math
import random

from sarathi.benchmark.config import PoissonRequestIntervalGeneratorConfig
from sarathi.benchmark.request_generator.base_request_interval_generator import (
BaseRequestIntervalGenerator,
)
from sarathi.config import PoissonRequestIntervalGeneratorConfig


class PoissonRequestIntervalGenerator(BaseRequestIntervalGenerator):
@@ -1,5 +1,6 @@
from typing import List

from sarathi.benchmark.config import SyntheticRequestGeneratorConfig
from sarathi.benchmark.entities import Request
from sarathi.benchmark.request_generator.base_request_generator import (
BaseRequestGenerator,
@@ -11,7 +12,6 @@
RequestLengthGeneratorRegistry,
)
from sarathi.benchmark.utils.random import set_seeds
from sarathi.config import SyntheticRequestGeneratorConfig


class SyntheticRequestGenerator(BaseRequestGenerator):
@@ -3,11 +3,11 @@

import pandas as pd

from sarathi.benchmark.config import TraceRequestGeneratorConfig
from sarathi.benchmark.entities import Request
from sarathi.benchmark.request_generator.base_request_generator import (
BaseRequestGenerator,
)
from sarathi.config import TraceRequestGeneratorConfig

logger = logging.getLogger(__name__)

@@ -2,10 +2,10 @@

import pandas as pd

from sarathi.benchmark.config import TraceRequestIntervalGeneratorConfig
from sarathi.benchmark.request_generator.base_request_interval_generator import (
BaseRequestIntervalGenerator,
)
from sarathi.config import TraceRequestIntervalGeneratorConfig

logger = logging.getLogger(__name__)

@@ -4,10 +4,10 @@
import numpy as np
import pandas as pd

from sarathi.benchmark.config import TraceRequestLengthGeneratorConfig
from sarathi.benchmark.request_generator.base_request_length_generator import (
BaseRequestLengthGenerator,
)
from sarathi.config import TraceRequestLengthGeneratorConfig

logger = logging.getLogger(__name__)

@@ -1,10 +1,10 @@
from typing import Tuple

from sarathi.benchmark.config import ZipfRequestLengthGeneratorConfig
from sarathi.benchmark.request_generator.base_request_length_generator import (
BaseRequestLengthGenerator,
)
from sarathi.benchmark.utils.zipf_generator import ZipfGenerator
from sarathi.config import ZipfRequestLengthGeneratorConfig


class ZipfRequestLengthGenerator(BaseRequestLengthGenerator):