Commit fd55f89 (1 parent: 3694f8f)
Showing 13 changed files with 1,855 additions and 58 deletions.
@@ -0,0 +1,20 @@
from sglang import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The capital of China is",
    "What is the meaning of life?",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="deepseek-ai/deepseek-llm-7b-chat")

outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for prompt, output in zip(prompts, outputs):
    print('===============================')
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
@@ -1,6 +1,8 @@
# SGL API Components

from sglang.api import (
    LLM,
    SamplingParams,
    Runtime,
    assistant,
    assistant_begin,
@@ -0,0 +1,218 @@
from dataclasses import dataclass, fields

from typing import Optional, Dict, Union, List

class ModelConfig:
    def __init__(self,
                 model_path: str,
                 load_format: str = "auto",
                 tokenizer_path: Optional[str] = None,
                 tokenizer_mode: str = "auto",
                 skip_tokenizer_init: bool = False,
                 dtype: str = "auto",
                 trust_remote_code: bool = True,
                 context_length: Optional[int] = None,
                 quantization: Optional[str] = None,
                 served_model_name: Optional[str] = None,
                 random_seed: Optional[int] = None,
                 stream_interval: int = 1,
                 tokenizer_port: int = 0,
                 detokenizer_port: int = 0,
                 controller_port: int = 0,
                 model_override_args: Optional[Dict] = None) -> None:
        """
        ModelConfig for model and tokenizer configuration.
        Args:
            model_path: Path to the model file or directory.
            load_format: Format to load the model. Default is 'auto'.
            tokenizer_path: Path to the tokenizer file or directory. Default is None.
            tokenizer_mode: Mode for loading the tokenizer. Default is 'auto'.
            skip_tokenizer_init: Whether to skip the tokenizer initialization. Default is False.
            dtype: Data type for the model. Default is 'auto'.
            trust_remote_code: Whether to trust and execute remote code from the model repository. Default is True.
            context_length: Maximum context length for the model. Default is None.
            quantization: Quantization method. Default is None.
            served_model_name: Custom name for the served model. Default is None.
            random_seed: Seed for random number generation. Default is None.
            stream_interval: Interval for streaming output. Default is 1.
            tokenizer_port: Port number for the tokenizer. Default is 0.
            detokenizer_port: Port number for the detokenizer. Default is 0.
            controller_port: Port number for the controller. Default is 0.
            model_override_args: Dictionary of model override arguments. Default is None.
        """
        self.model_path = model_path
        self.load_format = load_format
        self.tokenizer_path = tokenizer_path
        self.tokenizer_mode = tokenizer_mode
        self.skip_tokenizer_init = skip_tokenizer_init
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.context_length = context_length
        self.quantization = quantization
        self.served_model_name = served_model_name
        self.random_seed = random_seed
        self.stream_interval = stream_interval
        self.tokenizer_port = tokenizer_port
        self.detokenizer_port = detokenizer_port
        self.controller_port = controller_port
        self.model_override_args = model_override_args

    def __repr__(self):
        return (f"ModelConfig(model_path={self.model_path}, load_format={self.load_format}, "
                f"tokenizer_path={self.tokenizer_path}, tokenizer_mode={self.tokenizer_mode}, "
                f"skip_tokenizer_init={self.skip_tokenizer_init}, dtype={self.dtype}, "
                f"trust_remote_code={self.trust_remote_code}, context_length={self.context_length}, "
                f"quantization={self.quantization}, served_model_name={self.served_model_name}, "
                f"random_seed={self.random_seed}, stream_interval={self.stream_interval}, "
                f"tokenizer_port={self.tokenizer_port}, detokenizer_port={self.detokenizer_port}, "
                f"controller_port={self.controller_port}, model_override_args={self.model_override_args})")

class ScheduleConfig:
    def __init__(self,
                 mem_fraction_static: Optional[float] = None,
                 max_running_requests: Optional[int] = None,
                 max_num_reqs: Optional[int] = None,
                 max_total_tokens: Optional[int] = None,
                 chunked_prefill_size: int = 8192,
                 max_prefill_tokens: int = 16384,
                 schedule_policy: str = "lpm",
                 schedule_conservativeness: float = 1.0) -> None:
        """
        ScheduleConfig object for scheduling and memory management.
        Args:
            mem_fraction_static: Fraction of memory statically allocated. Default is None.
            max_running_requests: Maximum number of running requests. Default is None.
            max_num_reqs: Maximum number of requests. Default is None.
            max_total_tokens: Maximum total tokens allowed. Default is None.
            chunked_prefill_size: Size for chunked prefill. Default is 8192.
            max_prefill_tokens: Maximum tokens allowed in the prefill phase. Default is 16384.
            schedule_policy: Scheduling policy (e.g., 'lpm'). Default is 'lpm'.
            schedule_conservativeness: Conservativeness factor for scheduling. Default is 1.0.
        """
        self.mem_fraction_static = mem_fraction_static
        self.max_running_requests = max_running_requests
        self.max_num_reqs = max_num_reqs
        self.max_total_tokens = max_total_tokens
        self.chunked_prefill_size = chunked_prefill_size
        self.max_prefill_tokens = max_prefill_tokens
        self.schedule_policy = schedule_policy
        self.schedule_conservativeness = schedule_conservativeness

    def __repr__(self):
        return (f"ScheduleConfig(mem_fraction_static={self.mem_fraction_static}, "
                f"max_running_requests={self.max_running_requests}, max_num_reqs={self.max_num_reqs}, "
                f"max_total_tokens={self.max_total_tokens}, chunked_prefill_size={self.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, schedule_policy={self.schedule_policy}, "
                f"schedule_conservativeness={self.schedule_conservativeness})")

class ParallelConfig:
    def __init__(self,
                 tp_size: int = 1,
                 dp_size: int = 1,
                 load_balance_method: str = "round_robin",
                 nccl_init_addr: Optional[str] = None,
                 nccl_ports: List[int] = None,
                 additional_ports: Optional[Union[List[int], int]] = None,
                 nnodes: int = 1,
                 node_rank: Optional[int] = None) -> None:
        """
        ParallelConfig object for parallelism and distributed settings.
        Args:
            tp_size: Tensor parallelism size. Default is 1.
            dp_size: Data parallelism size. Default is 1.
            load_balance_method: Method for load balancing across nodes. Default is 'round_robin'.
            nccl_init_addr: NCCL initialization address. Default is None.
            nccl_ports: List of ports for NCCL communication. Default is None.
            additional_ports: Additional ports for distributed communication. Default is None.
            nnodes: Number of nodes in the distributed setup. Default is 1.
            node_rank: Rank of the current node. Default is None.
        """
        self.tp_size = tp_size
        self.dp_size = dp_size
        self.load_balance_method = load_balance_method
        self.nccl_init_addr = nccl_init_addr
        self.nccl_ports = nccl_ports
        self.additional_ports = additional_ports
        self.nnodes = nnodes
        self.node_rank = node_rank

    def __repr__(self):
        return (f"ParallelConfig(tp_size={self.tp_size}, dp_size={self.dp_size}, "
                f"load_balance_method={self.load_balance_method}, nccl_init_addr={self.nccl_init_addr}, "
                f"nccl_ports={self.nccl_ports}, additional_ports={self.additional_ports}, "
                f"nnodes={self.nnodes}, node_rank={self.node_rank})")

class OptimizationConfig:
    def __init__(self,
                 disable_flashinfer: bool = False,
                 disable_flashinfer_sampling: bool = False,
                 disable_radix_cache: bool = False,
                 disable_regex_jump_forward: bool = False,
                 disable_cuda_graph: bool = False,
                 disable_disk_cache: bool = False,
                 enable_torch_compile: bool = False,
                 enable_p2p_check: bool = False,
                 enable_mla: bool = False,
                 attention_reduce_in_fp32: bool = False,
                 efficient_weight_load: bool = False) -> None:
        """
        OptimizationConfig object for optimization and debug options.
        Args:
            disable_flashinfer: Disable flashinfer library. Default is False.
            disable_flashinfer_sampling: Disable flashinfer sampling. Default is False.
            disable_radix_cache: Disable radix cache optimization. Default is False.
            disable_regex_jump_forward: Disable regex-based jump forward optimization. Default is False.
            disable_cuda_graph: Disable CUDA graph optimization. Default is False.
            disable_disk_cache: Disable disk caching. Default is False.
            enable_torch_compile: Enable PyTorch compilation optimization. Default is False.
            enable_p2p_check: Enable peer-to-peer communication checks. Default is False.
            enable_mla: Enable Multi-Head Latent Attention from DeepSeek-V2. Default is False.
            attention_reduce_in_fp32: Perform attention reduction in FP32 precision. Default is False.
            efficient_weight_load: Enable efficient weight loading. Default is False.
        """
        self.disable_flashinfer = disable_flashinfer
        self.disable_flashinfer_sampling = disable_flashinfer_sampling
        self.disable_radix_cache = disable_radix_cache
        self.disable_regex_jump_forward = disable_regex_jump_forward
        self.disable_cuda_graph = disable_cuda_graph
        self.disable_disk_cache = disable_disk_cache
        self.enable_torch_compile = enable_torch_compile
        self.enable_p2p_check = enable_p2p_check
        self.enable_mla = enable_mla
        self.attention_reduce_in_fp32 = attention_reduce_in_fp32
        self.efficient_weight_load = efficient_weight_load

    def __repr__(self):
        return (f"OptimizationConfig(disable_flashinfer={self.disable_flashinfer}, "
                f"disable_flashinfer_sampling={self.disable_flashinfer_sampling}, disable_radix_cache={self.disable_radix_cache}, "
                f"disable_regex_jump_forward={self.disable_regex_jump_forward}, disable_cuda_graph={self.disable_cuda_graph}, "
                f"disable_disk_cache={self.disable_disk_cache}, enable_torch_compile={self.enable_torch_compile}, "
                f"enable_p2p_check={self.enable_p2p_check}, enable_mla={self.enable_mla}, "
                f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, efficient_weight_load={self.efficient_weight_load})")

@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """
    model_config: ModelConfig
    schedule_config: ScheduleConfig
    parallel_config: ParallelConfig
    optimization_config: OptimizationConfig

    def __post_init__(self):
        """Verify configs are valid & consistent with each other."""
        # TODO: Do validation
        pass

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs."""
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))
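For orientation, the sketch below shows how the four config objects introduced in this file could be combined into an EngineConfig. It only uses constructors, fields, and defaults visible in the diff; the specific argument values and the idea that the engine consumes the config this way are illustrative assumptions, not something this commit shows.

# Sketch (not part of this commit): wiring the configs defined above.
model_config = ModelConfig(model_path="deepseek-ai/deepseek-llm-7b-chat")
schedule_config = ScheduleConfig(mem_fraction_static=0.85)   # illustrative value
parallel_config = ParallelConfig(tp_size=2)                   # illustrative value
optimization_config = OptimizationConfig(enable_torch_compile=True)

engine_config = EngineConfig(
    model_config=model_config,
    schedule_config=schedule_config,
    parallel_config=parallel_config,
    optimization_config=optimization_config,
)
# to_dict() returns the fields keyed by name, for use as **kwargs.
print(engine_config.to_dict())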