Skip to content

Commit

Permalink
Use pythonic config management (#15)
Browse files Browse the repository at this point in the history
- New pythonic config library which works natively with dataclasses.
- Refactor code to work with new config library.
  • Loading branch information
AgrawalAmey authored Jun 20, 2024
1 parent d057e06 commit a1fec20
Show file tree
Hide file tree
Showing 66 changed files with 1,565 additions and 1,781 deletions.
8 changes: 4 additions & 4 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- setuptools
- pip
- make
- black
- isort
- flake8
- autopep8
- black=24.4.2
- isort=5.13.2
- flake8=7.1.0
- autopep8=2.3.0
37 changes: 28 additions & 9 deletions examples/offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from tqdm import tqdm
from typing import List

from sarathi.config import ModelConfig, ParallelConfig, SarathiSchedulerConfig, MetricsConfig, SystemConfig, ReplicaConfig
from sarathi.types import SchedulerType
from sarathi import LLMEngine, SamplingParams, RequestOutput


BASE_OUTPUT_DIR = "./offline_inference_output"

# Sample prompts.
Expand All @@ -23,23 +26,39 @@

output_dir = f"{BASE_OUTPUT_DIR}/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

llm_engine = LLMEngine.from_engine_args(
replica_config = ReplicaConfig(
output_dir=output_dir,
)

model_config = ModelConfig(
model="meta-llama/Llama-2-7b-hf",
# parallel config
tensor_parallel_size=4,
)

parallel_config = ParallelConfig(
tensor_parallel_size=2,
pipeline_parallel_size=2,
trust_remote_code=True,
max_model_len=4096,
# scheduler config
scheduler_type="sarathi",
)

scheduler_config = SarathiSchedulerConfig(
chunk_size=100,
max_num_seqs=4,
# metrics config
)

metrics_config = MetricsConfig(
write_metrics=False,
output_dir=output_dir,
enable_chrome_trace=True,
)

system_config = SystemConfig(
replica_config=replica_config,
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
metrics_config=metrics_config,
)

llm_engine = LLMEngine.from_system_config(system_config)


def generate(
llm_engine: LLMEngine,
Expand Down
2 changes: 0 additions & 2 deletions sarathi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from sarathi.core.datatypes.request_output import RequestOutput
from sarathi.core.datatypes.sampling_params import SamplingParams
from sarathi.engine.arg_utils import EngineArgs
from sarathi.engine.llm_engine import LLMEngine

__version__ = "0.1.7"
Expand All @@ -11,5 +10,4 @@
"SamplingParams",
"RequestOutput",
"LLMEngine",
"EngineArgs",
]
Loading

0 comments on commit a1fec20

Please sign in to comment.