diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index ad271c37edb..a89ae19ea5f 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -35,6 +35,14 @@ jobs: pip install -e "python[all]" pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + - name: Set PYTHONPATH + run: | + echo "PYTHONPATH=$PYTHONPATH:$(pwd)/python" >> $GITHUB_ENV + + - name: Verify import + run: | + python3 -c "import sglang.srt.serving" + - name: Benchmark Serving Throughput run: | cd test/srt diff --git a/.gitignore b/.gitignore index ca43e1ccba4..15e29a02f76 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +human-eval/ # Translations *.mo diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/usage/llava_video/srt_example_llava_v.py index 27ba862d30d..3f1998e9c0e 100644 --- a/examples/usage/llava_video/srt_example_llava_v.py +++ b/examples/usage/llava_video/srt_example_llava_v.py @@ -184,20 +184,20 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= print("Invalid model path. Please specify a valid model path.") exit() - model_overide_args = {} + model_override_args = {} - model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride - model_overide_args["architectures"] = ["LlavaVidForCausalLM"] - model_overide_args["num_frames"] = args.num_frames - model_overide_args["model_type"] = "llava" + model_override_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride + model_override_args["architectures"] = ["LlavaVidForCausalLM"] + model_override_args["num_frames"] = args.num_frames + model_override_args["model_type"] = "llava" if "34b" in args.model_path.lower(): - model_overide_args["image_token_index"] = 64002 + model_override_args["image_token_index"] = 64002 if args.num_frames == 32: - model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} - model_overide_args["max_sequence_length"] = 4096 * 2 - model_overide_args["tokenizer_model_max_length"] = 4096 * 2 + model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_override_args["max_sequence_length"] = 4096 * 2 + model_override_args["tokenizer_model_max_length"] = 4096 * 2 elif args.num_frames < 32: pass else: @@ -211,7 +211,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= tokenizer_path=tokenizer_path, port=cur_port, additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4], - model_overide_args=model_overide_args, + model_override_args=model_override_args, tp_size=1, ) sgl.set_default_backend(runtime) diff --git a/examples/usage/llm_engine.py b/examples/usage/llm_engine.py index a277d93f9de..dc564e0ba17 100644 --- a/examples/usage/llm_engine.py +++ b/examples/usage/llm_engine.py @@ -16,5 +16,5 @@ outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
for prompt, output in zip(prompts, outputs): - print('===============================') - print(f"Prompt: {prompt}\nGenerated text: {output['text']}") \ No newline at end of file + print("===============================") + print(f"Prompt: {prompt}\nGenerated text: {output['text']}") diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index 26d5d099b4d..e86c12faf37 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -2,8 +2,8 @@ from sglang.api import ( LLM, - SamplingParams, Runtime, + SamplingParams, assistant, assistant_begin, assistant_end, diff --git a/python/sglang/api.py b/python/sglang/api.py index fc9ada1d5c9..7d30c9f6340 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -18,8 +18,9 @@ SglSelect, SglVideo, ) -from sglang.srt.serving.engine import LLM from sglang.srt.sampling_params import SamplingParams +from sglang.srt.serving.engine import LLM + def function( func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None @@ -36,7 +37,7 @@ def decorator(func): def Runtime(*args, **kwargs): # Avoid importing unnecessary dependency os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - from sglang.srt.server import Runtime + from sglang.srt.serving.server import Runtime return Runtime(*args, **kwargs) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index dd86747e366..f42dcfe361b 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -49,13 +49,18 @@ import torch import torch.distributed as dist +from sglang.srt.config import ( + ModelConfig, + OptimizationConfig, + ParallelConfig, + ScheduleConfig, +) from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch -from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling_params import SamplingParams -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.server_args import ServerArgs from sglang.srt.utils import suppress_other_loggers @@ -111,15 +116,19 @@ def load_model(server_args, tp_rank): suppress_other_loggers() rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - model_config = ModelConfig(path=server_args.model_path) + model_config = ModelConfig(model_path=server_args.model_path) + optimization_config = OptimizationConfig() + parallel_config = ParallelConfig(tp_size=server_args.tp_size, nccl_ports=[28888]) + schedule_config = ScheduleConfig( + mem_fraction_static=server_args.mem_fraction_static + ) model_runner = ModelRunner( model_config=model_config, - mem_fraction_static=server_args.mem_fraction_static, + optimization_config=optimization_config, + parallel_config=parallel_config, + schedule_config=schedule_config, gpu_id=tp_rank, tp_rank=tp_rank, - tp_size=server_args.tp_size, - nccl_port=28888, - server_args=server_args, ) rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") tokenizer = get_tokenizer( diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 91dc0dc4e95..840461e54a5 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -2,8 +2,8 @@ import argparse -from sglang.srt.server import launch_server -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.server import launch_server +from sglang.srt.serving.server_args import ServerArgs if __name__ == "__main__": parser = 
argparse.ArgumentParser() diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index c34dd211672..af9415a984c 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -2,28 +2,29 @@ import argparse -from sglang.srt.server import ServerArgs, launch_server +from sglang.srt.serving.server import ServerArgs, launch_server if __name__ == "__main__": - model_overide_args = {} - - model_overide_args["mm_spatial_pool_stride"] = 2 - model_overide_args["architectures"] = ["LlavaVidForCausalLM"] - model_overide_args["num_frames"] = 16 - model_overide_args["model_type"] = "llavavid" - if model_overide_args["num_frames"] == 32: - model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} - model_overide_args["max_sequence_length"] = 4096 * 2 - model_overide_args["tokenizer_model_max_length"] = 4096 * 2 - model_overide_args["model_max_length"] = 4096 * 2 + model_override_args = {} + + model_override_args["mm_spatial_pool_stride"] = 2 + model_override_args["architectures"] = ["LlavaVidForCausalLM"] + model_override_args["num_frames"] = 16 + model_override_args["model_type"] = "llavavid" + if model_override_args["num_frames"] == 32: + model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_override_args["max_sequence_length"] = 4096 * 2 + model_override_args["tokenizer_model_max_length"] = 4096 * 2 + model_override_args["model_max_length"] = 4096 * 2 parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) args = parser.parse_args() if "34b" in args.model_path.lower(): - model_overide_args["image_token_index"] = 64002 + model_override_args["image_token_index"] = 64002 server_args = ServerArgs.from_cli_args(args) + server_args.model_override_args = model_override_args - launch_server(server_args, model_overide_args, None) + launch_server(server_args, None) diff --git a/python/sglang/srt/__init__.py b/python/sglang/srt/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/config.py b/python/sglang/srt/config.py index 063aaef6d28..1cc9a27583f 100644 --- a/python/sglang/srt/config.py +++ b/python/sglang/srt/config.py @@ -1,25 +1,38 @@ from dataclasses import dataclass, fields +from enum import IntEnum, auto +from typing import Dict, List, Optional, Union + +from transformers import PretrainedConfig + +from sglang.srt.hf_transformers_utils import get_config, get_context_length + + +class AttentionArch(IntEnum): + MLA = auto() + MHA = auto() -from typing import Optional, Dict, Union, List class ModelConfig: - def __init__(self, - model_path: str, - load_format: str = "auto", - tokenizer_path: Optional[str] = None, - tokenizer_mode: str = "auto", - skip_tokenizer_init: bool = False, - dtype: str = "auto", - trust_remote_code: bool = True, - context_length: Optional[int] = None, - quantization: Optional[str] = None, - served_model_name: Optional[str] = None, - random_seed: Optional[int] = None, - stream_interval: int = 1, - tokenizer_port: int = 0, - detokenizer_port: int = 0, - controller_port: int = 0, - model_override_args: Optional[Dict] = None) -> None: + def __init__( + self, + model_path: str, + load_format: str = "auto", + tokenizer_path: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + dtype: str = "auto", + trust_remote_code: bool = True, + revision: Optional[str] = None, + context_length: Optional[int] = None, + quantization: Optional[str] = None, + served_model_name: Optional[str] = 
None, + random_seed: Optional[int] = None, + stream_interval: int = 1, + tokenizer_port: int = 0, + detokenizer_port: int = 0, + controller_port: int = 0, + model_override_args: Optional[Dict] = None, + ) -> None: """ ModelConfig for model and tokenizer configuration. @@ -31,6 +44,7 @@ def __init__(self, skip_tokenizer_init: Whether to skip the tokenizer initialization. Default is False. dtype: Data type for the model. Default is 'auto'. trust_remote_code: Whether to trust and execute remote code from the model repository. Default is True. + revision: model revision str. Default is None, context_length: Maximum context length for the model. Default is None. quantization: Quantization method. Default is None. served_model_name: Custom name for the served model. Default is None. @@ -48,7 +62,7 @@ def __init__(self, self.skip_tokenizer_init = skip_tokenizer_init self.dtype = dtype self.trust_remote_code = trust_remote_code - self.context_length = context_length + self.revision = revision self.quantization = quantization self.served_model_name = served_model_name self.random_seed = random_seed @@ -58,28 +72,154 @@ def __init__(self, self.controller_port = controller_port self.model_override_args = model_override_args + self.hf_config = get_config( + self.model_path, + trust_remote_code, + revision, + model_override_args=model_override_args, + ) + self.hf_text_config = self.get_hf_text_config(self.hf_config) + if context_length is not None: + self.context_length = context_length + else: + self.context_length = get_context_length(self.hf_config) + + # Unify the config keys for hf_config + self.head_dim = getattr( + self.hf_config, + "head_dim", + self.hf_config.hidden_size // self.hf_config.num_attention_heads, + ) + + # FIXME: temporary special judge for deepseek v2 MLA architecture + if "DeepseekV2ForCausalLM" in self.hf_config.architectures: + self.head_dim = 256 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_config.kv_lora_rank + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + else: + self.attention_arch = AttentionArch.MHA + + self.num_attention_heads = self.hf_config.num_attention_heads + self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None) + + # for Dbrx and MPT models + if self.hf_config.model_type in ["dbrx", "mpt"]: + self.num_key_value_heads = getattr( + self.hf_config.attn_config, "kv_n_heads", None + ) + + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + self.hidden_size = self.hf_config.hidden_size + self.num_hidden_layers = self.hf_config.num_hidden_layers + self.vocab_size = self.hf_config.vocab_size + def __repr__(self): - return (f"ModelConfig(model_path={self.model_path}, load_format={self.load_format}, " - f"tokenizer_path={self.tokenizer_path}, tokenizer_mode={self.tokenizer_mode}, " - f"skip_tokenizer_init={self.skip_tokenizer_init}, dtype={self.dtype}, " - f"trust_remote_code={self.trust_remote_code}, context_length={self.context_length}, " - f"quantization={self.quantization}, served_model_name={self.served_model_name}, " - f"random_seed={self.random_seed}, stream_interval={self.stream_interval}, " - f"tokenizer_port={self.tokenizer_port}, detokenizer_port={self.detokenizer_port}, " - f"controller_port={self.controller_port}, model_override_args={self.model_override_args})") + return ( + f"ModelConfig(model_path={self.model_path}, load_format={self.load_format}, " + f"tokenizer_path={self.tokenizer_path}, tokenizer_mode={self.tokenizer_mode}, " + 
f"skip_tokenizer_init={self.skip_tokenizer_init}, dtype={self.dtype}, " + f"trust_remote_code={self.trust_remote_code}, revision={self.revision}, context_length={self.context_length}, " + f"quantization={self.quantization}, served_model_name={self.served_model_name}, " + f"random_seed={self.random_seed}, stream_interval={self.stream_interval}, " + f"tokenizer_port={self.tokenizer_port}, detokenizer_port={self.detokenizer_port}, " + f"controller_port={self.controller_port}, model_override_args={self.model_override_args})" + ) + + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. + return 1 + + # For DBRX and MPT + if self.hf_config.model_type in ["mpt"]: + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type in ["dbrx"]: + return getattr( + self.hf_config.attn_config, + "kv_n_heads", + self.hf_config.num_attention_heads, + ) + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328 + def get_num_kv_heads(self, tensor_parallel_size) -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, total_num_kv_heads // tensor_parallel_size) + + def get_hf_text_config(self, config: PretrainedConfig): + """Get the "sub" config relevant to llm for multi modal models. + No op for pure text models. + """ + class_name = config.architectures[0] + if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"): + # We support non-hf version of llava models, so we do not want to + # read the wrong values from the unused default text_config. + return config + + if hasattr(config, "text_config"): + # The code operates under the assumption that text_config should have + # `num_attention_heads` (among others). Assert here to fail early + # if transformers config doesn't align with this assumption. 
+ assert hasattr(config.text_config, "num_attention_heads") + return config.text_config + else: + return config + class ScheduleConfig: - def __init__(self, - mem_fraction_static: Optional[float] = None, - max_running_requests: Optional[int] = None, - max_num_reqs: Optional[int] = None, - max_total_tokens: Optional[int] = None, - chunked_prefill_size: int = 8192, - max_prefill_tokens: int = 16384, - schedule_policy: str = "lpm", - schedule_conservativeness: float = 1.0) -> None: + def __init__( + self, + mem_fraction_static: Optional[float] = None, + max_running_requests: Optional[int] = None, + max_num_reqs: Optional[int] = None, + max_total_tokens: Optional[int] = None, + chunked_prefill_size: int = 8192, + max_prefill_tokens: int = 16384, + schedule_policy: str = "lpm", + schedule_conservativeness: float = 1.0, + ) -> None: """ - ScheduleConfig object for scheduling and memory management + ScheduleConfig object for scheduling and memory management Args: mem_fraction_static: Fraction of memory statically allocated. Default is None. @@ -101,22 +241,27 @@ def __init__(self, self.schedule_conservativeness = schedule_conservativeness def __repr__(self): - return (f"ScheduleConfig(mem_fraction_static={self.mem_fraction_static}, " - f"max_running_requests={self.max_running_requests}, max_num_reqs={self.max_num_reqs}, " - f"max_total_tokens={self.max_total_tokens}, chunked_prefill_size={self.chunked_prefill_size}, " - f"max_prefill_tokens={self.max_prefill_tokens}, schedule_policy={self.schedule_policy}, " - f"schedule_conservativeness={self.schedule_conservativeness})") + return ( + f"ScheduleConfig(mem_fraction_static={self.mem_fraction_static}, " + f"max_running_requests={self.max_running_requests}, max_num_reqs={self.max_num_reqs}, " + f"max_total_tokens={self.max_total_tokens}, chunked_prefill_size={self.chunked_prefill_size}, " + f"max_prefill_tokens={self.max_prefill_tokens}, schedule_policy={self.schedule_policy}, " + f"schedule_conservativeness={self.schedule_conservativeness})" + ) + class ParallelConfig: - def __init__(self, - tp_size: int = 1, - dp_size: int = 1, - load_balance_method: str = "round_robin", - nccl_init_addr: Optional[str] = None, - nccl_ports: List[int] = None, - additional_ports: Optional[Union[List[int], int]] = None, - nnodes: int = 1, - node_rank: Optional[int] = None) -> None: + def __init__( + self, + tp_size: int = 1, + dp_size: int = 1, + load_balance_method: str = "round_robin", + nccl_init_addr: Optional[str] = None, + nccl_ports: List[int] = None, + additional_ports: Optional[Union[List[int], int]] = None, + nnodes: int = 1, + node_rank: Optional[int] = None, + ) -> None: """ ParallelConfig object for parallelism and distributed settings. 
@@ -140,24 +285,30 @@ def __init__(self, self.node_rank = node_rank def __repr__(self): - return (f"ParallelConfig(tp_size={self.tp_size}, dp_size={self.dp_size}, " - f"load_balance_method={self.load_balance_method}, nccl_init_addr={self.nccl_init_addr}, " - f"nccl_ports={self.nccl_ports}, additional_ports={self.additional_ports}, " - f"nnodes={self.nnodes}, node_rank={self.node_rank})") + return ( + f"ParallelConfig(tp_size={self.tp_size}, dp_size={self.dp_size}, " + f"load_balance_method={self.load_balance_method}, nccl_init_addr={self.nccl_init_addr}, " + f"nccl_ports={self.nccl_ports}, additional_ports={self.additional_ports}, " + f"nnodes={self.nnodes}, node_rank={self.node_rank})" + ) + class OptimizationConfig: - def __init__(self, - disable_flashinfer: bool = False, - disable_flashinfer_sampling: bool = False, - disable_radix_cache: bool = False, - disable_regex_jump_forward: bool = False, - disable_cuda_graph: bool = False, - disable_disk_cache: bool = False, - enable_torch_compile: bool = False, - enable_p2p_check: bool = False, - enable_mla: bool = False, - attention_reduce_in_fp32: bool = False, - efficient_weight_load: bool = False) -> None: + def __init__( + self, + disable_flashinfer: bool = False, + disable_flashinfer_sampling: bool = False, + disable_radix_cache: bool = False, + disable_regex_jump_forward: bool = False, + disable_cuda_graph: bool = False, + disable_disk_cache: bool = False, + enable_torch_compile: bool = False, + enable_p2p_check: bool = False, + enable_mixed_chunk: bool = False, + enable_mla: bool = False, + attention_reduce_in_fp32: bool = False, + efficient_weight_load: bool = False, + ) -> None: """ OptimizationConfig object for optimization and debug options @@ -170,6 +321,7 @@ def __init__(self, disable_disk_cache: Disable disk caching. Default is False. enable_torch_compile: Enable PyTorch compilation optimization. Default is False. enable_p2p_check: Enable peer-to-peer communication checks. Default is False. + enable_mixed_chunk: Enable mixed chunk. Default is False. enable_mla: Enable Multi-Head Latent Attention from DeepSeek-V2. Default is False. attention_reduce_in_fp32: Perform attention reduction in FP32 precision. Default is False. efficient_weight_load: Enable efficient weight loading. Default is False. 
@@ -182,17 +334,52 @@ def __init__(self, self.disable_disk_cache = disable_disk_cache self.enable_torch_compile = enable_torch_compile self.enable_p2p_check = enable_p2p_check + self.enable_mixed_chunk = enable_mixed_chunk self.enable_mla = enable_mla self.attention_reduce_in_fp32 = attention_reduce_in_fp32 self.efficient_weight_load = efficient_weight_load def __repr__(self): - return (f"OptimizationConfig(disable_flashinfer={self.disable_flashinfer}, " - f"disable_flashinfer_sampling={self.disable_flashinfer_sampling}, disable_radix_cache={self.disable_radix_cache}, " - f"disable_regex_jump_forward={self.disable_regex_jump_forward}, disable_cuda_graph={self.disable_cuda_graph}, " - f"disable_disk_cache={self.disable_disk_cache}, enable_torch_compile={self.enable_torch_compile}, " - f"enable_p2p_check={self.enable_p2p_check}, enable_mla={self.enable_mla}, " - f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, efficient_weight_load={self.efficient_weight_load})") + return ( + f"OptimizationConfig(disable_flashinfer={self.disable_flashinfer}, " + f"disable_flashinfer_sampling={self.disable_flashinfer_sampling}, disable_radix_cache={self.disable_radix_cache}, " + f"disable_regex_jump_forward={self.disable_regex_jump_forward}, disable_cuda_graph={self.disable_cuda_graph}, " + f"disable_disk_cache={self.disable_disk_cache}, enable_torch_compile={self.enable_torch_compile}, " + f"enable_p2p_check={self.enable_p2p_check}, enable_mixed_chunk={self.enable_mixed_chunk}, enable_mla={self.enable_mla}, " + f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, efficient_weight_load={self.efficient_weight_load})" + ) + + +class ObservabilityConfig: + """ + ObservabilityConfig object for log and observability settings. + + Args: + log_level (str): Log level, default is "info". + log_level_http (Optional[str]): Log level for HTTP server, default is None. + log_requests (bool): Whether to log requests, default is False. + show_time_cost (bool): Whether to enable show time cost for debugging, default is False. + """ + + def __init__( + self, + log_level: str = "info", + log_level_http: Optional[str] = None, + log_requests: bool = False, + show_time_cost: bool = False, + ): + self.log_level = log_level + self.log_level_http = log_level_http + self.log_requests = log_requests + self.show_time_cost = show_time_cost + + def __repr__(self) -> str: + return ( + f"ObservabilityConfig(log_level={self.log_level}, " + f"log_level_http={self.log_level_http}, " + f"log_requests={self.log_requests}), " + f"show_time_cost={self.show_time_cost}" + ) @dataclass(frozen=True) @@ -200,19 +387,18 @@ class EngineConfig: """Dataclass which contains all engine-related configuration. This simplifies passing around the distinct configurations in the codebase. """ + model_config: ModelConfig schedule_config: ScheduleConfig parallel_config: ParallelConfig optimization_config: OptimizationConfig + observability_config: ObservabilityConfig def __post_init__(self): - """Verify configs are valid & consistent with each other. - """ - # TODO: Do validation + """Verify configs are valid & consistent with each other.""" + # TODO: Do validation for each *Config pass def to_dict(self): - """Return the configs as a dictionary, for use in **kwargs. 
- """ - return dict( - (field.name, getattr(self, field.name)) for field in fields(self)) \ No newline at end of file + """Return the configs as a dictionary, for use in **kwargs.""" + return dict((field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index fb198fd73ca..8f8f3ef73d5 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -62,7 +62,7 @@ def get_config( model: str, trust_remote_code: bool, revision: Optional[str] = None, - model_overide_args: Optional[dict] = None, + model_override_args: Optional[dict] = None, ): config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision @@ -70,8 +70,8 @@ def get_config( if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) - if model_overide_args: - config.update(model_overide_args) + if model_override_args: + config.update(model_override_args) return config diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py index adbe40b9ba2..0dd92412c85 100644 --- a/python/sglang/srt/managers/controller_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -28,6 +28,13 @@ import numpy as np import zmq +from sglang.srt.config import ( + ModelConfig, + ObservabilityConfig, + OptimizationConfig, + ParallelConfig, + ScheduleConfig, +) from sglang.srt.managers.controller_single import ( start_controller_process as start_controller_process_single, ) @@ -36,7 +43,6 @@ FlushCacheReq, TokenizedGenerateReqInput, ) -from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_parent_process from sglang.utils import get_exception_traceback @@ -71,26 +77,27 @@ class ControllerMulti: def __init__( self, - server_args: ServerArgs, - controller_port: int, - detokenizer_port: int, - nccl_ports: List[int], - model_overide_args, + model_config: ModelConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, + optimization_config: OptimizationConfig, + observability_config: ObservabilityConfig, ): - # Parse args - self.server_args = server_args - self.controller_port = controller_port - self.detokenizer_port = detokenizer_port - self.nccl_ports = nccl_ports - self.model_overide_args = model_overide_args self.load_balance_method = LoadBalanceMethod.from_str( - server_args.load_balance_method + parallel_config.load_balance_method ) + self.model_config = model_config + self.parallel_config = parallel_config + self.schedule_config = schedule_config + self.optimization_config = optimization_config + self.observability_config = observability_config # Init communication context = zmq.Context() self.recv_from_tokenizer = context.socket(zmq.PULL) - self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{controller_port}") + self.recv_from_tokenizer.bind( + f"tcp://127.0.0.1:{self.model_config.controller_port}" + ) # Dispatch method self.round_robin_counter = 0 @@ -102,11 +109,11 @@ def __init__( # Start data parallel workers self.workers = [] - for i in range(server_args.dp_size): + for i in range(parallel_config.dp_size): self.start_dp_worker(i) def start_dp_worker(self, dp_worker_id: int): - tp_size = self.server_args.tp_size + tp_size = self.tp_size pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe( duplex=False @@ -117,12 +124,12 @@ def start_dp_worker(self, 
dp_worker_id: int): proc = multiprocessing.Process( target=start_controller_process_single, args=( - self.server_args, - self.controller_port, - self.detokenizer_port, - self.nccl_ports, + self.model_config, + self.parallel_config, + self.schedule_config, + self.optimization_config, + self.observability_config, pipe_controller_writer, - self.model_overide_args, True, gpu_ids, dp_worker_id, @@ -194,23 +201,28 @@ def recv_requests(self): def start_controller_process( - server_args: ServerArgs, - controller_port: int, - detokenizer_port: int, - nccl_ports: List[int], - pipe_writer, - model_overide_args: dict, + model_config: ModelConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, + optimization_config: OptimizationConfig, + observability_config: ObservabilityConfig, + pipe_writer: multiprocessing.connection.Connection, ): """Start a controller process.""" logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), + level=getattr(logging, observability_config.log_level.upper()), format="%(message)s", ) try: - controller = ControllerMulti(server_args, controller_port, detokenizer_port, - nccl_ports, model_overide_args) + controller = ControllerMulti( + model_config, + parallel_config, + schedule_config, + optimization_config, + observability_config, + ) except Exception: pipe_writer.send(get_exception_traceback()) raise diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index a4b57073a33..fbe7967dfab 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -22,12 +22,18 @@ import zmq +from sglang.srt.config import ( + ModelConfig, + ObservabilityConfig, + OptimizationConfig, + ParallelConfig, + ScheduleConfig, +) from sglang.srt.managers.tp_worker import ( ModelTpServer, broadcast_recv_input, launch_tp_servers, ) -from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import kill_parent_process from sglang.utils import get_exception_traceback @@ -39,18 +45,17 @@ class ControllerSingle: def __init__( self, - server_args: ServerArgs, - controller_port: int, - detokenizer_port: int, - nccl_ports: List[int], - model_overide_args: dict, + model_config: ModelConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, + optimization_config: OptimizationConfig, gpu_ids: List[int], is_data_parallel_worker: bool, dp_worker_id: int, mp_queue: multiprocessing.Queue, ): # Parse args - self.tp_size = server_args.tp_size + self.tp_size = parallel_config.tp_size self.is_dp_worker = is_data_parallel_worker self.dp_worker_id = dp_worker_id self.mp_queue = mp_queue @@ -61,35 +66,39 @@ def __init__( if not self.is_dp_worker: self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer.bind( - f"tcp://127.0.0.1:{controller_port}" + f"tcp://127.0.0.1:{model_config.controller_port}" + ) + logging.info( + f"ZeroMQ PULL socket created and binding to tcp://127.0.0.1:{model_config.controller_port}" ) - logging.info(f'ZeroMQ PULL socket created and binding to tcp://127.0.0.1:{controller_port}') self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer.connect( - f"tcp://127.0.0.1:{detokenizer_port}" + f"tcp://127.0.0.1:{model_config.detokenizer_port}" ) # Launch other tp ranks - tp_size_local = server_args.tp_size // server_args.nnodes + tp_size_local = parallel_config.tp_size // parallel_config.nnodes self.tp_procs = [] if tp_size_local > 1: tp_rank_range = range(1, tp_size_local) 
self.tp_procs = launch_tp_servers( gpu_ids, tp_rank_range, - server_args, - nccl_ports[dp_worker_id], - model_overide_args, + model_config, + parallel_config, + schedule_config, + optimization_config, ) # Launch tp rank 0 self.tp_server = ModelTpServer( gpu_ids[0], 0, - server_args, - nccl_ports[dp_worker_id], - model_overide_args, + model_config, + parallel_config, + schedule_config, + optimization_config, ) self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group @@ -127,12 +136,12 @@ def recv_requests_from_mp_queue(self): def start_controller_process( - server_args: ServerArgs, - controller_port: int, - detokenizer_port: int, - nccl_ports: List[int], + model_config: ModelConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, + optimization_config: OptimizationConfig, + observability_config: ObservabilityConfig, pipe_writer: multiprocessing.connection.Connection, - model_overide_args: dict, is_data_parallel_worker: bool = False, gpu_ids: List[int] = None, dp_worker_id: int = None, @@ -141,23 +150,24 @@ def start_controller_process( """Start a controller process.""" logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), + level=getattr(logging, observability_config.log_level.upper()), format="%(message)s", ) if not is_data_parallel_worker: - tp_size_local = server_args.tp_size // server_args.nnodes - gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)] + tp_size_local = parallel_config.tp_size // parallel_config.nnodes + gpu_ids = [ + i for _ in range(parallel_config.nnodes) for i in range(tp_size_local) + ] dp_worker_id = 0 queue = None try: controller = ControllerSingle( - server_args, - controller_port, - detokenizer_port, - nccl_ports, - model_overide_args, + model_config, + parallel_config, + schedule_config, + optimization_config, gpu_ids, is_data_parallel_worker, dp_worker_id, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index b90da67e768..51627a7c99b 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -23,6 +23,7 @@ import zmq import zmq.asyncio +from sglang.srt.config import ModelConfig from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.io_struct import ( BatchEmbeddingOut, @@ -30,7 +31,6 @@ BatchTokenIDOut, ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR -from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import find_printable_text, get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -48,24 +48,22 @@ class DecodeStatus: class DetokenizerManager: def __init__( self, - server_args: ServerArgs, - tokenizer_port: int, - detokenizer_port: int, + model_config: ModelConfig, ): context = zmq.asyncio.Context(2) self.recv_from_router = context.socket(zmq.PULL) - self.recv_from_router.bind(f"tcp://127.0.0.1:{detokenizer_port}") + self.recv_from_router.bind(f"tcp://127.0.0.1:{model_config.detokenizer_port}") self.send_to_tokenizer = context.socket(zmq.PUSH) - self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{tokenizer_port}") + self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{model_config.tokenizer_port}") - if server_args.skip_tokenizer_init: + if model_config.skip_tokenizer_init: self.tokenizer = None else: self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + 
model_config.tokenizer_path, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, ) self.decode_status = {} @@ -159,13 +157,11 @@ async def handle_loop(self): def start_detokenizer_process( - server_args: ServerArgs, - tokenizer_port: int, - detokenizer_port: int, + model_config: ModelConfig, pipe_writer, ): try: - manager = DetokenizerManager(server_args, tokenizer_port, detokenizer_port) + manager = DetokenizerManager(model_config) except Exception: pipe_writer.send(get_exception_traceback()) raise diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3b08789d92a..d50c4b04b28 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -30,6 +30,7 @@ import zmq.asyncio from fastapi import BackgroundTasks +from sglang.srt.config import ModelConfig, ObservabilityConfig from sglang.srt.hf_transformers_utils import ( get_config, get_context_length, @@ -49,7 +50,6 @@ ) from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams -from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image from sglang.utils import get_exception_traceback @@ -68,61 +68,60 @@ class ReqState: class TokenizerManager: def __init__( self, - server_args: ServerArgs, - tokenizer_port: int, - controller_port: int, - model_overide_args: dict = None, + model_config: ModelConfig, + observability_config: ObservabilityConfig, ): - self.server_args = server_args - context = zmq.asyncio.Context(2) self.recv_from_detokenizer = context.socket(zmq.PULL) - self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{tokenizer_port}") + self.recv_from_detokenizer.bind( + f"tcp://127.0.0.1:{model_config.tokenizer_port}" + ) self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"tcp://127.0.0.1:{controller_port}") + self.send_to_router.connect(f"tcp://127.0.0.1:{model_config.controller_port}") - self.model_path = server_args.model_path - self.served_model_name = server_args.served_model_name + self.model_path = model_config.model_path + self.served_model_name = model_config.served_model_name self.hf_config = get_config( self.model_path, - trust_remote_code=server_args.trust_remote_code, - model_overide_args=model_overide_args, + trust_remote_code=model_config.trust_remote_code, + model_override_args=model_config.model_override_args, ) self.is_generation = is_generation_model(self.hf_config.architectures) - if server_args.context_length is not None: - self.context_len = server_args.context_length + if model_config.context_length is not None: + self.context_len = model_config.context_length else: self.context_len = get_context_length(self.hf_config) - if server_args.skip_tokenizer_init: + if model_config.skip_tokenizer_init: self.tokenizer = self.processor = None else: if is_multimodal_model(self.model_path): self.processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + model_config.tokenizer_path, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, mp_context=mp.get_context("fork"), - 
initargs=(server_args,), + initargs=(model_config,), ) else: self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + model_config.tokenizer_path, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, ) self.to_create_loop = True self.handle_loop_task = None self.should_stop_loop = False self.rid_to_state: Dict[str, ReqState] = {} + self.observability_config = observability_config async def get_pixel_values(self, image_data): aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) @@ -454,7 +453,7 @@ async def _wait_for_response( out = state.out_list[-1] # Log requests - if self.server_args.log_requests and state.finished: + if self.observability_config.log_requests and state.finished: if obj.text is None: in_obj = {"input_ids": obj.input_ids} else: @@ -520,7 +519,9 @@ async def abort_request(): def create_handle_loop(self): self.to_create_loop = False loop = asyncio.get_event_loop() - assert self.handle_loop_task is None, "handle_loop_task exists when callig create_handle_loop" + assert ( + self.handle_loop_task is None + ), "handle_loop_task exists when callig create_handle_loop" self.handle_loop_task = loop.create_task(self.handle_loop()) async def handle_loop(self): @@ -530,7 +531,9 @@ async def handle_loop(self): try: while True: if self.should_stop_loop and not poller.poll(timeout=0): - logger.info("No more messages and shutdown requested, exiting loop.") + logger.info( + "No more messages and shutdown requested, exiting loop." + ) break # any new events? @@ -637,7 +640,7 @@ def shutdown(self): without synchronization prmitive. This API is synchronous, it will set the should_stop_loop flag to - bring the event loop down, and closes the sockets and the ZMQ context + bring the event loop down, and closes the sockets and the ZMQ context in a safe manner. """ # This flags the handle_loop() to stop(when finishing its last job.) @@ -649,7 +652,7 @@ def shutdown(self): async def _shutdown_async(self): """Asynchronous part of shutdown logic. - This is not an exposed API. (Shall we do a async shutdown??) + This is not an exposed API. (Shall we do a async shutdown??) """ # Wait for the handle_loop() is done, which means # self.recv_from_detokenizer is finished with its job. 
@@ -663,7 +666,7 @@ async def _shutdown_async(self): if not self.send_to_router.closed: self.send_to_router.close() logger.info("send_to_router socket closed.") - + # Now close the receiver if not self.recv_from_detokenizer.closed: self.recv_from_detokenizer.close() @@ -680,13 +683,13 @@ async def _shutdown_async(self): global global_processor -def init_global_processor(server_args: ServerArgs): +def init_global_processor(model_config: ModelConfig): global global_processor transformers.logging.set_verbosity_error() global_processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + model_config.tokenizer_path, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, ) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b6cfa68bd4a..8e6ab44a65c 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -28,6 +28,12 @@ import torch.distributed as dist from sglang.global_config import global_config +from sglang.srt.config import ( + ModelConfig, + OptimizationConfig, + ParallelConfig, + ScheduleConfig, +) from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer @@ -49,10 +55,8 @@ ) from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache -from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( is_multimodal_model, set_random_seed, @@ -71,67 +75,62 @@ def __init__( self, gpu_id: int, tp_rank: int, - server_args: ServerArgs, - nccl_port: int, - model_overide_args: dict, + model_config: ModelConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, + optimization_config: OptimizationConfig, ): suppress_other_loggers() # Copy arguments self.gpu_id = gpu_id self.tp_rank = tp_rank - self.tp_size = server_args.tp_size - self.dp_size = server_args.dp_size - self.schedule_policy = server_args.schedule_policy - self.disable_regex_jump_forward = server_args.disable_regex_jump_forward + self.tp_size = parallel_config.tp_size + self.dp_size = parallel_config.dp_size + self.schedule_policy = schedule_config.schedule_policy + self.disable_regex_jump_forward = optimization_config.disable_regex_jump_forward # Init model and tokenizer - self.model_config = ModelConfig( - server_args.model_path, - server_args.trust_remote_code, - context_length=server_args.context_length, - model_overide_args=model_overide_args, - ) + self.model_config = model_config self.model_runner = ModelRunner( - model_config=self.model_config, - mem_fraction_static=server_args.mem_fraction_static, + model_config=model_config, + optimization_config=optimization_config, + parallel_config=parallel_config, + schedule_config=schedule_config, gpu_id=gpu_id, tp_rank=tp_rank, - tp_size=server_args.tp_size, - nccl_port=nccl_port, - server_args=server_args, ) - if server_args.skip_tokenizer_init: + if model_config.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(server_args.model_path): + if is_multimodal_model(model_config.model_path): self.processor = get_processor( - 
server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + model_config.tokenizer_path, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, ) self.tokenizer = self.processor.tokenizer else: self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + model_config.tokenizer_path, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, ) self.max_total_num_tokens = self.model_runner.max_total_num_tokens - self.max_prefill_tokens = server_args.max_prefill_tokens + self.max_prefill_tokens = schedule_config.max_prefill_tokens self.max_running_requests = min( ( self.max_total_num_tokens // 2 - if server_args.max_running_requests is None - else server_args.max_running_requests + if schedule_config.max_running_requests is None + else schedule_config.max_running_requests ), self.model_runner.req_to_token_pool.size - 1, ) self.max_req_input_len = min( - self.model_config.context_len - 1, + self.model_config.context_length - 1, self.max_total_num_tokens - 1, ) - set_random_seed(server_args.random_seed) + set_random_seed(model_config.random_seed) # Print info logger.info( @@ -139,13 +138,13 @@ def __init__( f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_running_requests={self.max_running_requests}, " - f"context_len={self.model_config.context_len}" + f"context_len={self.model_config.context_length}" ) # Init cache if ( - server_args.chunked_prefill_size is not None - and server_args.disable_radix_cache + schedule_config.chunked_prefill_size is not None + and optimization_config.disable_radix_cache ): self.tree_cache = ChunkCache( req_to_token_pool=self.model_runner.req_to_token_pool, @@ -155,7 +154,7 @@ def __init__( self.tree_cache = RadixCache( req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, - disable=server_args.disable_radix_cache, + disable=optimization_config.disable_radix_cache, ) self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache) @@ -167,36 +166,37 @@ def __init__( self.running_batch: ScheduleBatch = None self.out_pyobjs = [] self.decode_forward_ct = 0 - self.stream_interval = server_args.stream_interval + self.stream_interval = model_config.stream_interval self.num_generated_tokens = 0 self.last_stats_tic = time.time() # Chunked prefill - self.chunked_prefill_size = server_args.chunked_prefill_size + self.chunked_prefill_size = schedule_config.chunked_prefill_size self.current_inflight_req = None self.is_mixed_chunk = ( - self.chunked_prefill_size is not None and server_args.enable_mixed_chunk + self.chunked_prefill_size is not None + and optimization_config.enable_mixed_chunk ) # Init the FSM cache for constrained generation - if not server_args.skip_tokenizer_init: + if not model_config.skip_tokenizer_init: self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, + model_config.tokenizer_path, { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, + "tokenizer_mode": model_config.tokenizer_mode, + "trust_remote_code": model_config.trust_remote_code, }, - skip_tokenizer_init=server_args.skip_tokenizer_init, + skip_tokenizer_init=model_config.skip_tokenizer_init, ) self.jump_forward_cache = 
JumpForwardCache() # Init new token estimation assert ( - server_args.schedule_conservativeness >= 0 + schedule_config.schedule_conservativeness >= 0 ), "Invalid schedule_conservativeness" self.min_new_token_ratio = min( global_config.base_min_new_token_ratio - * server_args.schedule_conservativeness, + * schedule_config.schedule_conservativeness, 1.0, ) self.new_token_ratio = self.min_new_token_ratio @@ -802,18 +802,20 @@ def abort_request(self, recv_req): def run_tp_server( gpu_id: int, tp_rank: int, - server_args: ServerArgs, - nccl_port: int, - model_overide_args: dict, + model_config: ModelConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, + optimization_config: OptimizationConfig, ): """Run a tensor parallel server.""" try: model_server = ModelTpServer( gpu_id, tp_rank, - server_args, - nccl_port, - model_overide_args, + model_config, + parallel_config, + schedule_config, + optimization_config, ) tp_cpu_group = model_server.model_runner.tp_group.cpu_group @@ -828,16 +830,24 @@ def run_tp_server( def launch_tp_servers( gpu_ids: List[int], tp_rank_range: List[int], - server_args: ServerArgs, - nccl_port: int, - model_overide_args: dict, + model_config: ModelConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, + optimization_config: OptimizationConfig, ): """Launch multiple tensor parallel servers.""" procs = [] for i in tp_rank_range: proc = multiprocessing.Process( target=run_tp_server, - args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args), + args=( + gpu_ids[i], + i, + model_config, + parallel_config, + schedule_config, + optimization_config, + ), ) proc.start() procs.append(proc) diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py deleted file mode 100644 index ed496515cd3..00000000000 --- a/python/sglang/srt/model_config.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -from enum import IntEnum, auto -from typing import Optional - -from transformers import PretrainedConfig - -from sglang.srt.hf_transformers_utils import get_config, get_context_length - - -class AttentionArch(IntEnum): - MLA = auto() - MHA = auto() - - -class ModelConfig: - def __init__( - self, - path: str, - trust_remote_code: bool = True, - revision: Optional[str] = None, - context_length: Optional[int] = None, - model_overide_args: Optional[dict] = None, - ) -> None: - self.path = path - self.trust_remote_code = trust_remote_code - self.revision = revision - self.model_overide_args = model_overide_args - self.hf_config = get_config( - self.path, - trust_remote_code, - revision, - model_overide_args=model_overide_args, - ) - self.hf_text_config = get_hf_text_config(self.hf_config) - if context_length is not None: - self.context_len = context_length - else: - self.context_len = get_context_length(self.hf_config) - - # Unify the config keys for hf_config - self.head_dim = getattr( - self.hf_config, - "head_dim", - self.hf_config.hidden_size // self.hf_config.num_attention_heads, - ) - - # FIXME: temporary special judge for deepseek v2 MLA architecture - if "DeepseekV2ForCausalLM" in self.hf_config.architectures: - self.head_dim = 256 - self.attention_arch = AttentionArch.MLA - self.kv_lora_rank = self.hf_config.kv_lora_rank - self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim - else: - self.attention_arch = AttentionArch.MHA - - self.num_attention_heads = self.hf_config.num_attention_heads - self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None) - - # for Dbrx and MPT models - if self.hf_config.model_type in ["dbrx", "mpt"]: - self.num_key_value_heads = getattr( - self.hf_config.attn_config, "kv_n_heads", None - ) - - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - self.hidden_size = self.hf_config.hidden_size - self.num_hidden_layers = self.hf_config.num_hidden_layers - self.vocab_size = self.hf_config.vocab_size - - # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 - def get_total_num_kv_heads(self) -> int: - """Returns the total number of KV heads.""" - # For GPTBigCode & Falcon: - # NOTE: for falcon, when new_decoder_architecture is True, the - # multi_query flag is ignored and we use n_head_kv for the number of - # KV heads. - falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = ( - self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False) - ) - if not new_decoder_arch_falcon and getattr( - self.hf_text_config, "multi_query", False - ): - # Multi-query attention, only one KV head. - # Currently, tensor parallelism is not supported in this case. 
- return 1 - - # For DBRX and MPT - if self.hf_config.model_type in ["mpt"]: - if "kv_n_heads" in self.hf_config.attn_config: - return self.hf_config.attn_config["kv_n_heads"] - return self.hf_config.num_attention_heads - if self.hf_config.model_type in ["dbrx"]: - return getattr( - self.hf_config.attn_config, - "kv_n_heads", - self.hf_config.num_attention_heads, - ) - - attributes = [ - # For Falcon: - "n_head_kv", - "num_kv_heads", - # For LLaMA-2: - "num_key_value_heads", - # For ChatGLM: - "multi_query_group_num", - ] - for attr in attributes: - num_kv_heads = getattr(self.hf_text_config, attr, None) - if num_kv_heads is not None: - return num_kv_heads - - # For non-grouped-query attention models, the number of KV heads is - # equal to the number of attention heads. - return self.hf_text_config.num_attention_heads - - # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328 - def get_num_kv_heads(self, tensor_parallel_size) -> int: - """Returns the number of KV heads per GPU.""" - total_num_kv_heads = self.get_total_num_kv_heads() - # If tensor parallelism is used, we divide the number of KV heads by - # the tensor parallel size. We will replicate the KV heads in the - # case where the number of KV heads is smaller than the tensor - # parallel size so each GPU has at least one KV head. - return max(1, total_num_kv_heads // tensor_parallel_size) - - -def get_hf_text_config(config: PretrainedConfig): - """Get the "sub" config relevant to llm for multi modal models. - No op for pure text models. - """ - class_name = config.architectures[0] - if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"): - # We support non-hf version of llava models, so we do not want to - # read the wrong values from the unused default text_config. - return config - - if hasattr(config, "text_config"): - # The code operates under the assumption that text_config should have - # `num_attention_heads` (among others). Assert here to fail early - # if transformers config doesn't align with this assumption. 
- assert hasattr(config.text_config, "num_attention_heads") - return config.text_config - else: - return config diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index af39065cfa5..60ccbe75450 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -111,7 +111,7 @@ def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): (self.max_bs + 1,), dtype=torch.int32, device="cuda" ) self.flashinfer_kv_indices = torch.zeros( - (self.max_bs * model_runner.model_config.context_len,), + (self.max_bs * model_runner.model_config.context_length,), dtype=torch.int32, device="cuda", ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3cf68eab24d..efeb158465c 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -174,18 +174,18 @@ def from_schedule_batch( if ( forward_mode != ForwardMode.DECODE - or model_runner.server_args.disable_flashinfer + or model_runner.optimization_config.disable_flashinfer ): ret.total_num_tokens = int(torch.sum(ret.seq_lens)) if forward_mode != ForwardMode.DECODE: ret.init_multimuldal_info(batch) - if model_runner.server_args.disable_flashinfer: + if model_runner.optimization_config.disable_flashinfer: ret.init_triton_args(batch) flashinfer_use_ragged = False - if not model_runner.server_args.disable_flashinfer: + if not model_runner.optimization_config.disable_flashinfer: if ( forward_mode != ForwardMode.DECODE and int(torch.sum(ret.seq_lens)) > 4096 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b74a19e60df..f590b3257a3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -43,15 +43,20 @@ from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config +from sglang.srt.config import ( + AttentionArch, + ModelConfig, + OptimizationConfig, + ParallelConfig, + ScheduleConfig, +) from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_config import AttentionArch from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, is_generation_model, @@ -68,29 +73,30 @@ class ModelRunner: def __init__( self, - model_config, - mem_fraction_static: float, + model_config: ModelConfig, + optimization_config: OptimizationConfig, + parallel_config: ParallelConfig, + schedule_config: ScheduleConfig, gpu_id: int, tp_rank: int, - tp_size: int, - nccl_port: int, - server_args: ServerArgs, ): # Parse args self.model_config = model_config - self.mem_fraction_static = mem_fraction_static + self.optimization_config = optimization_config + self.parallel_config = parallel_config + self.schedule_config = schedule_config + self.mem_fraction_static = schedule_config.mem_fraction_static self.gpu_id = gpu_id self.tp_rank = tp_rank - self.tp_size = tp_size - self.nccl_port = nccl_port - self.server_args = server_args + self.tp_size = parallel_config.tp_size + self.nccl_port = parallel_config.nccl_ports[0] self.is_multimodal_model = 
is_multimodal_model(self.model_config) global_server_args_dict.update( { - "disable_flashinfer": server_args.disable_flashinfer, - "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, - "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32, - "enable_mla": server_args.enable_mla, + "disable_flashinfer": optimization_config.disable_flashinfer, + "disable_flashinfer_sampling": optimization_config.disable_flashinfer_sampling, + "attention_reduce_in_fp32": optimization_config.attention_reduce_in_fp32, + "enable_mla": optimization_config.enable_mla, } ) @@ -98,11 +104,11 @@ def __init__( torch.cuda.set_device(self.gpu_id) logger.info(f"[gpu={self.gpu_id}] Init nccl begin.") - if not server_args.enable_p2p_check: + if not optimization_config.enable_p2p_check: monkey_patch_vllm_p2p_access_check(self.gpu_id) - if server_args.nccl_init_addr: - nccl_init_method = f"tcp://{server_args.nccl_init_addr}" + if parallel_config.nccl_init_addr: + nccl_init_method = f"tcp://{parallel_config.nccl_init_addr}" else: nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}" init_distributed_environment( @@ -132,8 +138,8 @@ def __init__( self.load_model() self.init_memory_pool( total_gpu_memory, - server_args.max_num_reqs, - server_args.max_total_tokens, + schedule_config.max_num_reqs, + schedule_config.max_total_tokens, ) self.init_cublas() self.init_flashinfer() @@ -156,14 +162,14 @@ def load_model(self): monkey_patch_vllm_dummy_weight_loader() device_config = DeviceConfig() - load_config = LoadConfig(load_format=self.server_args.load_format) + load_config = LoadConfig(load_format=self.model_config.load_format) vllm_model_config = VllmModelConfig( - model=self.server_args.model_path, - quantization=self.server_args.quantization, + model=self.model_config.model_path, + quantization=self.model_config.quantization, tokenizer=None, tokenizer_mode=None, - trust_remote_code=self.server_args.trust_remote_code, - dtype=self.server_args.dtype, + trust_remote_code=self.model_config.trust_remote_code, + dtype=self.model_config.dtype, seed=42, skip_tokenizer_init=True, ) @@ -175,8 +181,8 @@ def load_model(self): monkey_patch_vllm_qvk_linear_loader() self.dtype = vllm_model_config.dtype - if self.model_config.model_overide_args is not None: - vllm_model_config.hf_config.update(self.model_config.model_overide_args) + if self.model_config.model_override_args is not None: + vllm_model_config.hf_config.update(self.model_config.model_override_args) self.model = get_model( model_config=vllm_model_config, @@ -210,7 +216,7 @@ def profile_max_num_token(self, total_gpu_memory): ) if ( self.model_config.attention_arch == AttentionArch.MLA - and self.server_args.enable_mla + and self.optimization_config.enable_mla ): cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) @@ -253,7 +259,9 @@ def init_memory_pool( max_num_reqs = min( max( int( - self.max_total_num_tokens / self.model_config.context_len * 512 + self.max_total_num_tokens + / self.model_config.context_length + * 512 ), 2048, ), @@ -262,11 +270,11 @@ def init_memory_pool( self.req_to_token_pool = ReqToTokenPool( max_num_reqs, - self.model_config.context_len + 8, + self.model_config.context_length + 8, ) if ( self.model_config.attention_arch == AttentionArch.MLA - and self.server_args.enable_mla + and self.optimization_config.enable_mla ): self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, @@ -277,7 +285,7 @@ def init_memory_pool( ) logger.info("using MLA Triton implementaion, flashinfer is disabled") # 
FIXME: temporarily only Triton MLA is supported - self.server_args.disable_flashinfer = True + self.optimization_config.disable_flashinfer = True else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, @@ -301,7 +309,7 @@ def init_cublas(self): return c def init_flashinfer(self): - if self.server_args.disable_flashinfer: + if self.optimization_config.disable_flashinfer: assert ( self.sliding_window_size is None ), "turn on flashinfer to support window attention" @@ -363,7 +371,10 @@ def init_flashinfer(self): def init_cuda_graphs(self): from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner - if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: + if ( + self.optimization_config.disable_cuda_graph + or self.optimization_config.disable_flashinfer + ): self.cuda_graph_runner = None return @@ -374,7 +385,7 @@ def init_cuda_graphs(self): self.cuda_graph_runner = CudaGraphRunner( self, max_batch_size_to_capture=max(batch_size_list), - use_torch_compile=self.server_args.enable_torch_compile, + use_torch_compile=self.optimization_config.enable_torch_compile, ) try: self.cuda_graph_runner.capture(batch_size_list) diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py index 770bb2e87b1..0fc0944365f 100644 --- a/python/sglang/srt/sampling_params.py +++ b/python/sglang/srt/sampling_params.py @@ -141,4 +141,4 @@ def to_dict(self): "spaces_between_special_tokens": self.spaces_between_special_tokens, "regex": self.regex, "n": self.n, - } \ No newline at end of file + } diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py deleted file mode 100644 index 9028c12309b..00000000000 --- a/python/sglang/srt/server.py +++ /dev/null @@ -1,590 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -""" -The entry point of inference server. -SRT = SGLang Runtime. 
-""" - -import asyncio -import dataclasses -import json -import logging -import multiprocessing as mp -import os -import sys -import threading -import time -from http import HTTPStatus -from typing import Dict, List, Optional, Union - -# Fix a bug of Python threading -setattr(threading, "_register_atexit", lambda *args, **kwargs: None) - -import aiohttp -import requests -import uvicorn -import uvloop -from fastapi import FastAPI, File, Form, Request, UploadFile -from fastapi.responses import JSONResponse, Response, StreamingResponse - -from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.constrained import disable_cache -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller_multi import ( - start_controller_process as start_controller_process_multi, -) -from sglang.srt.managers.controller_single import launch_tp_servers -from sglang.srt.managers.controller_single import ( - start_controller_process as start_controller_process_single, -) -from sglang.srt.managers.detokenizer_manager import start_detokenizer_process -from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput -from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.openai_api.adapter import ( - load_chat_template_for_openai_api, - v1_batches, - v1_chat_completions, - v1_completions, - v1_delete_file, - v1_embeddings, - v1_files_create, - v1_retrieve_batch, - v1_retrieve_file, - v1_retrieve_file_content, -) -from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - add_api_key_middleware, - allocate_init_ports, - assert_pkg_version, - enable_show_time_cost, - kill_child_process, - maybe_set_triton_cache_manager, - prepare_model, - prepare_tokenizer, - set_ulimit, -) -from sglang.utils import get_exception_traceback - -logger = logging.getLogger(__name__) - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - - -app = FastAPI() -tokenizer_manager = None - - -@app.get("/health") -async def health() -> Response: - """Health check.""" - return Response(status_code=200) - - -@app.get("/get_model_info") -async def get_model_info(): - result = { - "model_path": tokenizer_manager.model_path, - "is_generation": tokenizer_manager.is_generation, - } - return result - - -@app.get("/get_server_args") -async def get_server_args(): - return dataclasses.asdict(tokenizer_manager.server_args) - - -@app.get("/flush_cache") -async def flush_cache(): - tokenizer_manager.flush_cache() - return Response( - content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", - status_code=200, - ) - - -async def generate_request(obj: GenerateReqInput, request: Request): - """Handle a generate request.""" - if obj.stream: - - async def stream_results(): - try: - async for out in tokenizer_manager.generate_request(obj, request): - yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" - except ValueError as e: - out = {"error": {"message": str(e)}} - yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj), - ) - else: - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return JSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -app.post("/generate")(generate_request) -app.put("/generate")(generate_request) - - -async def encode_request(obj: EmbeddingReqInput, request: Request): - """Handle an embedding request.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return JSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -app.post("/encode")(encode_request) -app.put("/encode")(encode_request) - - -@app.post("/v1/completions") -async def openai_v1_completions(raw_request: Request): - return await v1_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/chat/completions") -async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/embeddings") -async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(tokenizer_manager, raw_request) - return response - - -@app.get("/v1/models") -def available_models(): - """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] - model_cards = [] - for served_model_name in served_model_names: - model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) - return ModelList(data=model_cards) - - -@app.post("/v1/files") -async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, tokenizer_manager.server_args.file_storage_pth - ) - - -@app.delete("/v1/files/{file_id}") -async def delete_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/delete - return await v1_delete_file(file_id) - - -@app.post("/v1/batches") -async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager, raw_request) - - -@app.get("/v1/batches/{batch_id}") -async def retrieve_batch(batch_id: str): - return await v1_retrieve_batch(batch_id) - - -@app.get("/v1/files/{file_id}") -async def retrieve_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve - return await v1_retrieve_file(file_id) - - -@app.get("/v1/files/{file_id}/content") -async def retrieve_file_content(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve-contents - return await v1_retrieve_file_content(file_id) - - -def launch_server( - server_args: ServerArgs, - model_overide_args: Optional[dict] = None, - pipe_finish_writer: Optional[mp.connection.Connection] = None, -): - """Launch an HTTP server.""" - global tokenizer_manager - - logging.basicConfig( - 
level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) - - server_args.check_server_args() - _set_envs_and_config(server_args) - - # Allocate ports - server_args.port, server_args.additional_ports = allocate_init_ports( - server_args.port, - server_args.additional_ports, - server_args.dp_size, - ) - ports = server_args.additional_ports - port_args = PortArgs( - tokenizer_port=ports[0], - controller_port=ports[1], - detokenizer_port=ports[2], - nccl_ports=ports[3:], - ) - logger.info(f"{server_args=}") - - # Use model from www.modelscope.cn, first download the model. - server_args.model_path = prepare_model(server_args.model_path) - server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path) - - # Launch processes for multi-node tensor parallelism - if server_args.nnodes > 1: - if server_args.node_rank != 0: - tp_size_local = server_args.tp_size // server_args.nnodes - gpu_ids = [ - i for _ in range(server_args.nnodes) for i in range(tp_size_local) - ] - tp_rank_range = list( - range( - server_args.node_rank * tp_size_local, - (server_args.node_rank + 1) * tp_size_local, - ) - ) - procs = launch_tp_servers( - gpu_ids, - tp_rank_range, - server_args, - ports[3], - model_overide_args, - ) - while True: - pass - - # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) - pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) - - if server_args.dp_size == 1: - start_process = start_controller_process_single - else: - start_process = start_controller_process_multi - proc_controller = mp.Process( - target=start_process, - args=(server_args, port_args, pipe_controller_writer, model_overide_args), - ) - proc_controller.start() - proc_detoken = mp.Process( - target=start_detokenizer_process, - args=( - server_args, - port_args, - pipe_detoken_writer, - ), - ) - proc_detoken.start() - - # Wait for the model to finish loading - controller_init_state = pipe_controller_reader.recv() - detoken_init_state = pipe_detoken_reader.recv() - - if controller_init_state != "init ok" or detoken_init_state != "init ok": - proc_controller.kill() - proc_detoken.kill() - print( - f"Initialization failed. controller_init_state: {controller_init_state}", - flush=True, - ) - print( - f"Initialization failed. 
detoken_init_state: {detoken_init_state}", - flush=True, - ) - sys.exit(1) - assert proc_controller.is_alive() and proc_detoken.is_alive() - - # Add api key authorization - if server_args.api_key: - add_api_key_middleware(app, server_args.api_key) - - # Send a warmup request - t = threading.Thread( - target=_wait_and_warmup, args=(server_args, pipe_finish_writer) - ) - t.start() - - # Listen for requests - try: - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - ) - finally: - t.join() - - -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - - # Set ulimit - set_ulimit() - - # Enable show time cost for debugging - if server_args.show_time_cost: - enable_show_time_cost() - - # Disable disk cache - if server_args.disable_disk_cache: - disable_cache() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if not server_args.disable_flashinfer: - assert_pkg_version( - "flashinfer", - "0.1.5", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - -def _wait_and_warmup(server_args, pipe_finish_writer): - headers = {} - url = server_args.url() - if server_args.api_key: - headers["Authorization"] = f"Bearer {server_args.api_key}" - - # Wait until the server is launched - success = False - for _ in range(120): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res}" - success = True - break - except (AssertionError, requests.exceptions.RequestException) as e: - last_traceback = get_exception_traceback() - pass - model_info = res.json() - - if not success: - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - print(f"Initialization failed. warmup error: {last_traceback}", flush=True) - sys.exit(1) - - # Send a warmup request - request_name = "/generate" if model_info["is_generation"] else "/encode" - max_new_tokens = 8 if model_info["is_generation"] else 1 - json_data = { - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - } - if server_args.skip_tokenizer_init: - json_data["input_ids"] = [10, 11, 12] - else: - json_data["text"] = "The capital city of France is" - - try: - for _ in range(server_args.dp_size): - res = requests.post( - url + request_name, - json=json_data, - headers=headers, - timeout=600, - ) - assert res.status_code == 200, f"{res}" - except Exception as e: - last_traceback = get_exception_traceback() - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - print(f"Initialization failed. warmup error: {last_traceback}", flush=True) - sys.exit(1) - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("init ok") - - -class Runtime: - """ - A wrapper for the server. - This is used for launching the server in a python program without - using the commond line interface. 
- """ - - def __init__( - self, - log_level: str = "error", - model_overide_args: Optional[dict] = None, - *args, - **kwargs, - ): - """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) - - # Pre-allocate ports - self.server_args.port, self.server_args.additional_ports = allocate_init_ports( - self.server_args.port, - self.server_args.additional_ports, - self.server_args.dp_size, - ) - - self.url = self.server_args.url() - self.generate_url = ( - f"http://{self.server_args.host}:{self.server_args.port}/generate" - ) - - self.pid = None - pipe_reader, pipe_writer = mp.Pipe(duplex=False) - proc = mp.Process( - target=launch_server, - args=(self.server_args, model_overide_args, pipe_writer), - ) - proc.start() - pipe_writer.close() - self.pid = proc.pid - - try: - init_state = pipe_reader.recv() - except EOFError: - init_state = "" - - if init_state != "init ok": - self.shutdown() - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) - - self.endpoint = RuntimeEndpoint(self.url) - - def shutdown(self): - if self.pid is not None: - kill_child_process(self.pid) - self.pid = None - - def cache_prefix(self, prefix: str): - self.endpoint.cache_prefix(prefix) - - def get_tokenizer(self): - return get_tokenizer( - self.server_args.tokenizer_path, - tokenizer_mode=self.server_args.tokenizer_mode, - trust_remote_code=self.server_args.trust_remote_code, - ) - - async def async_generate( - self, - prompt: str, - sampling_params: Optional[Dict] = None, - ): - if self.server_args.skip_tokenizer_init: - json_data = { - "input_ids": prompt, - "sampling_params": sampling_params, - "stream": True, - } - else: - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "stream": True, - } - pos = 0 - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.post(self.generate_url, json=json_data) as response: - async for chunk, _ in response.content.iter_chunks(): - chunk = chunk.decode("utf-8") - if chunk and chunk.startswith("data:"): - if chunk == "data: [DONE]\n\n": - break - data = json.loads(chunk[5:].strip("\n")) - if hasattr(data, "text"): - cur = data["text"][pos:] - if cur: - yield cur - pos += len(cur) - else: - yield data - - add_request = async_generate - - def generate( - self, - prompt: str, - sampling_params: Optional[Dict] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - top_logprobs_num: Optional[Union[List[int], int]] = None, - ): - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "return_logprob": return_logprob, - "top_logprobs_num": top_logprobs_num, - } - response = requests.post( - self.url + "/generate", - json=json_data, - ) - return json.dumps(response.json()) - - def encode( - self, - prompt: str, - ): - json_data = { - "text": prompt, - } - response = requests.post( - self.url + "/encode", - json=json_data, - ) - return json.dumps(response.json()) - - def __del__(self): - self.shutdown() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py deleted file mode 100644 index 99ecff6a588..00000000000 --- a/python/sglang/srt/server_args.py +++ /dev/null @@ -1,468 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""The arguments of the server.""" - -import argparse -import dataclasses -import logging -import random -from typing import List, Optional, Union - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class ServerArgs: - # Model and tokenizer - model_path: str - tokenizer_path: Optional[str] = None - tokenizer_mode: str = "auto" - skip_tokenizer_init: bool = False - load_format: str = "auto" - dtype: str = "auto" - trust_remote_code: bool = True - context_length: Optional[int] = None - quantization: Optional[str] = None - served_model_name: Optional[str] = None - chat_template: Optional[str] = None - - # Port - host: str = "127.0.0.1" - port: int = 30000 - additional_ports: Optional[Union[List[int], int]] = None - - # Memory and scheduling - mem_fraction_static: Optional[float] = None - max_running_requests: Optional[int] = None - max_num_reqs: Optional[int] = None - max_total_tokens: Optional[int] = None - chunked_prefill_size: int = 8192 - max_prefill_tokens: int = 16384 - schedule_policy: str = "lpm" - schedule_conservativeness: float = 1.0 - - # Other runtime options - tp_size: int = 1 - stream_interval: int = 1 - random_seed: Optional[int] = None - - # Logging - log_level: str = "info" - log_level_http: Optional[str] = None - log_requests: bool = False - show_time_cost: bool = False - - # Other - api_key: Optional[str] = None - file_storage_pth: str = "SGLang_storage" - - # Data parallelism - dp_size: int = 1 - load_balance_method: str = "round_robin" - - # Optimization/debug options - disable_flashinfer: bool = False - disable_flashinfer_sampling: bool = False - disable_radix_cache: bool = False - disable_regex_jump_forward: bool = False - disable_cuda_graph: bool = False - disable_disk_cache: bool = False - enable_mixed_chunk: bool = False - enable_torch_compile: bool = False - enable_p2p_check: bool = False - enable_mla: bool = False - attention_reduce_in_fp32: bool = False - efficient_weight_load: bool = False - - # Distributed args - nccl_init_addr: Optional[str] = None - nnodes: int = 1 - node_rank: Optional[int] = None - - def __post_init__(self): - if self.tokenizer_path is None: - self.tokenizer_path = self.model_path - - if self.served_model_name is None: - self.served_model_name = self.model_path - - if self.chunked_prefill_size <= 0: - # Disable chunked prefill - self.chunked_prefill_size = None - - if self.mem_fraction_static is None: - if self.tp_size >= 16: - self.mem_fraction_static = 0.79 - elif self.tp_size >= 8: - self.mem_fraction_static = 0.83 - elif self.tp_size >= 4: - self.mem_fraction_static = 0.85 - elif self.tp_size >= 2: - self.mem_fraction_static = 0.87 - else: - self.mem_fraction_static = 0.88 - - if isinstance(self.additional_ports, int): - self.additional_ports = [self.additional_ports] - elif self.additional_ports is None: - self.additional_ports = [] - - if self.random_seed is None: - self.random_seed = random.randint(0, 1 << 30) - - @staticmethod - def add_cli_args(parser: argparse.ArgumentParser): - parser.add_argument( - "--model-path", - type=str, - help="The path of the model weights. 
This can be a local folder or a Hugging Face repo ID.", - required=True, - ) - parser.add_argument( - "--tokenizer-path", - type=str, - default=ServerArgs.tokenizer_path, - help="The path of the tokenizer.", - ) - parser.add_argument( - "--host", type=str, default=ServerArgs.host, help="The host of the server." - ) - parser.add_argument( - "--port", type=int, default=ServerArgs.port, help="The port of the server." - ) - parser.add_argument( - "--additional-ports", - type=int, - nargs="*", - default=[], - help="The additional ports specified for the server.", - ) - parser.add_argument( - "--tokenizer-mode", - type=str, - default=ServerArgs.tokenizer_mode, - choices=["auto", "slow"], - help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", - ) - parser.add_argument( - "--skip-tokenizer-init", - action="store_true", - help="If set, skip init tokenizer and pass input_ids in generate request", - ) - parser.add_argument( - "--load-format", - type=str, - default=ServerArgs.load_format, - choices=["auto", "pt", "safetensors", "npcache", "dummy"], - help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling.", - ) - parser.add_argument( - "--dtype", - type=str, - default=ServerArgs.dtype, - choices=["auto", "half", "float16", "bfloat16", "float", "float32"], - help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', - ) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", - ) - parser.add_argument( - "--context-length", - type=int, - default=ServerArgs.context_length, - help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", - ) - parser.add_argument( - "--quantization", - type=str, - default=ServerArgs.quantization, - choices=[ - "awq", - "fp8", - "gptq", - "marlin", - "gptq_marlin", - "awq_marlin", - "squeezellm", - "bitsandbytes", - ], - help="The quantization method.", - ) - parser.add_argument( - "--served-model-name", - type=str, - default=ServerArgs.served_model_name, - help="Override the model name returned by the v1/models endpoint in OpenAI API server.", - ) - parser.add_argument( - "--chat-template", - type=str, - default=ServerArgs.chat_template, - help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.", - ) - parser.add_argument( - "--mem-fraction-static", - type=float, - default=ServerArgs.mem_fraction_static, - help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). 
Use a smaller value if you see out-of-memory errors.", - ) - parser.add_argument( - "--max-running-requests", - type=int, - default=ServerArgs.max_running_requests, - help="The maximum number of running requests.", - ) - parser.add_argument( - "--max-num-reqs", - type=int, - default=ServerArgs.max_num_reqs, - help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", - ) - parser.add_argument( - "--max-total-tokens", - type=int, - default=ServerArgs.max_total_tokens, - help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.", - ) - parser.add_argument( - "--chunked-prefill-size", - type=int, - default=ServerArgs.chunked_prefill_size, - help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill", - ) - parser.add_argument( - "--max-prefill-tokens", - type=int, - default=ServerArgs.max_prefill_tokens, - help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.", - ) - parser.add_argument( - "--schedule-policy", - type=str, - default=ServerArgs.schedule_policy, - choices=["lpm", "random", "fcfs", "dfs-weight"], - help="The scheduling policy of the requests.", - ) - parser.add_argument( - "--schedule-conservativeness", - type=float, - default=ServerArgs.schedule_conservativeness, - help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.", - ) - parser.add_argument( - "--tensor-parallel-size", - "--tp-size", - type=int, - default=ServerArgs.tp_size, - help="The tensor parallelism size.", - ) - parser.add_argument( - "--stream-interval", - type=int, - default=ServerArgs.stream_interval, - help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher", - ) - parser.add_argument( - "--random-seed", - type=int, - default=ServerArgs.random_seed, - help="The random seed.", - ) - parser.add_argument( - "--log-level", - type=str, - default=ServerArgs.log_level, - help="The logging level of all loggers.", - ) - parser.add_argument( - "--log-level-http", - type=str, - default=ServerArgs.log_level_http, - help="The logging level of HTTP server. If not set, reuse --log-level by default.", - ) - parser.add_argument( - "--log-requests", - action="store_true", - help="Log the inputs and outputs of all requests.", - ) - parser.add_argument( - "--show-time-cost", - action="store_true", - help="Show time cost of custom marks.", - ) - parser.add_argument( - "--api-key", - type=str, - default=ServerArgs.api_key, - help="Set API key of the server. 
It is also used in the OpenAI API compatible server.", - ) - parser.add_argument( - "--file-storage-pth", - type=str, - default=ServerArgs.file_storage_pth, - help="The path of the file storage in backend.", - ) - - # Data parallelism - parser.add_argument( - "--data-parallel-size", - "--dp-size", - type=int, - default=ServerArgs.dp_size, - help="The data parallelism size.", - ) - parser.add_argument( - "--load-balance-method", - type=str, - default=ServerArgs.load_balance_method, - help="The load balancing strategy for data parallelism.", - choices=[ - "round_robin", - "shortest_queue", - ], - ) - - # Multi-node distributed serving args - parser.add_argument( - "--nccl-init-addr", - type=str, - help="The nccl init address of multi-node server.", - ) - parser.add_argument( - "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." - ) - parser.add_argument("--node-rank", type=int, help="The node rank.") - - # Optimization/debug options - parser.add_argument( - "--disable-flashinfer", - action="store_true", - help="Disable flashinfer attention kernels.", - ) - parser.add_argument( - "--disable-flashinfer-sampling", - action="store_true", - help="Disable flashinfer sampling kernels.", - ) - parser.add_argument( - "--disable-radix-cache", - action="store_true", - help="Disable RadixAttention for prefix caching.", - ) - parser.add_argument( - "--disable-regex-jump-forward", - action="store_true", - help="Disable regex jump-forward.", - ) - parser.add_argument( - "--disable-cuda-graph", - action="store_true", - help="Disable cuda graph.", - ) - parser.add_argument( - "--disable-disk-cache", - action="store_true", - help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", - ) - parser.add_argument( - "--enable-mixed-chunk", - action="store_true", - help="Enabling mixing prefill and decode in a chunked batch.", - ) - parser.add_argument( - "--enable-torch-compile", - action="store_true", - help="Optimize the model with torch.compile, experimental feature.", - ) - parser.add_argument( - "--enable-p2p-check", - action="store_true", - help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.", - ) - parser.add_argument( - "--enable-mla", - action="store_true", - help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2", - ) - parser.add_argument( - "--attention-reduce-in-fp32", - action="store_true", - help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." 
- "This only affects Triton attention kernels", - ) - parser.add_argument( - "--efficient-weight-load", - action="store_true", - help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", - ) - - @classmethod - def from_cli_args(cls, args: argparse.Namespace): - args.tp_size = args.tensor_parallel_size - args.dp_size = args.data_parallel_size - attrs = [attr.name for attr in dataclasses.fields(cls)] - return cls(**{attr: getattr(args, attr) for attr in attrs}) - - def url(self): - return f"http://{self.host}:{self.port}" - - def print_mode_args(self): - return ( - f"disable_flashinfer={self.disable_flashinfer}, " - f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, " - f"disable_radix_cache={self.disable_radix_cache}, " - f"disable_regex_jump_forward={self.disable_regex_jump_forward}, " - f"disable_disk_cache={self.disable_disk_cache}, " - ) - - def check_server_args(self): - assert ( - self.tp_size % self.nnodes == 0 - ), "tp_size must be divisible by number of nodes" - assert not ( - self.dp_size > 1 and self.node_rank is not None - ), "multi-node data parallel is not supported" - if "gemma-2" in self.model_path.lower(): - logger.info(f"When using sliding window in gemma-2, turn on flashinfer.") - self.disable_flashinfer = False - - -@dataclasses.dataclass -class PortArgs: - tokenizer_port: int - controller_port: int - detokenizer_port: int - nccl_ports: List[int] diff --git a/python/sglang/srt/serving/__init__.py b/python/sglang/srt/serving/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/serving/engine.py b/python/sglang/srt/serving/engine.py index b84df5601a1..fe57dbfacf6 100644 --- a/python/sglang/srt/serving/engine.py +++ b/python/sglang/srt/serving/engine.py @@ -12,25 +12,21 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + import asyncio import logging import multiprocessing as mp import sys -from typing import Optional, Union, List, Dict - -import argparse -from sglang.srt.server_args import ServerArgs +from dataclasses import fields +from typing import Dict, List, Optional, Union -from sglang.srt.serving.engine_args import EngineArgs -from sglang.srt.serving.server_args import ServerArgs from sglang.srt.config import ( ModelConfig, - ScheduleConfig, - ParallelConfig, + ObservabilityConfig, OptimizationConfig, + ParallelConfig, + ScheduleConfig, ) -from sglang.srt.sampling_params import SamplingParams -from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.managers.controller_multi import ( start_controller_process as start_controller_process_multi, ) @@ -39,40 +35,50 @@ start_controller_process as start_controller_process_single, ) from sglang.srt.managers.detokenizer_manager import start_detokenizer_process +from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.utils import ( - prepare_model, - prepare_tokenizer, -) +from sglang.srt.sampling_params import SamplingParams +from sglang.srt.serving.engine_args import EngineArgs +from sglang.srt.utils import prepare_model, prepare_tokenizer logger = logging.getLogger(__name__) + class Engine: """ The core LLM Engine """ - def __init__(self, - server_args: ServerArgs, # get rid of it latter - model_config: ModelConfig, - schedule_config: ScheduleConfig, - parallel_config: ParallelConfig, - optimization_config: OptimizationConfig + + def __init__( + self, + model_config: ModelConfig, + schedule_config: ScheduleConfig, + parallel_config: ParallelConfig, + optimization_config: OptimizationConfig, + observability_config: ObservabilityConfig, ): self.model_config = model_config self.schedule_config = schedule_config self.parallel_config = parallel_config self.optimization_config = optimization_config + self.observability_config = observability_config # Use model from www.modelscope.cn, first download the model. 
self.model_config.model_path = prepare_model(self.model_config.model_path) - self.model_config.tokenizer_path = prepare_tokenizer(self.model_config.tokenizer_path) + self.model_config.tokenizer_path = prepare_tokenizer( + self.model_config.tokenizer_path + ) # Launch processes for multi-node tensor parallelism if self.parallel_config.nnodes > 1: if self.parallel_config.node_rank != 0: - tp_size_local = self.parallel_config.tp_size // self.parallel_config.nnodes + tp_size_local = ( + self.parallel_config.tp_size // self.parallel_config.nnodes + ) gpu_ids = [ - i for _ in range(self.parallel_config.nnodes) for i in range(tp_size_local) + i + for _ in range(self.parallel_config.nnodes) + for i in range(tp_size_local) ] tp_rank_range = list( range( @@ -83,19 +89,18 @@ def __init__(self, procs = launch_tp_servers( gpu_ids, tp_rank_range, - server_args, - self.parallel_config.nccl_ports[0], - self.model_config.model_overide_args, + self.model_config, + self.parallel_config, + self.schedule_config, + self.optimization_config, ) while True: pass - # Launch processes - self.tokenizer_manager = TokenizerManager(server_args, - self.model_config.tokenizer_port, - self.model_config.controller_port, - self.model_config.model_override_args) + self.tokenizer_manager = TokenizerManager( + self.model_config, self.observability_config + ) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) @@ -105,20 +110,20 @@ def __init__(self, start_process = start_controller_process_multi self.proc_controller = mp.Process( target=start_process, - args=(server_args, - self.model_config.controller_port, - self.model_config.detokenizer_port, - self.parallel_config.nccl_ports, - pipe_controller_writer, - self.model_config.model_override_args), + args=( + self.model_config, + self.parallel_config, + self.schedule_config, + self.optimization_config, + self.observability_config, + pipe_controller_writer, + ), ) self.proc_controller.start() self.proc_detoken = mp.Process( target=start_detokenizer_process, args=( - server_args, - self.model_config.tokenizer_port, - self.model_config.detokenizer_port, + self.model_config, pipe_detoken_writer, ), ) @@ -147,7 +152,7 @@ def shutdown(self): # controller to bring it down, then tokenizer_manager shutdown itself accordingly. 
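The controller and detokenizer workers above run as separate processes and report readiness back over one-way pipes before the engine accepts traffic. A minimal, self-contained sketch of that handshake pattern — the worker function here is a toy stand-in, not the real start_controller_process/start_detokenizer_process targets:

import multiprocessing as mp

def worker(pipe_writer):
    # ... heavy initialization (model load, NCCL setup, etc.) would happen here ...
    pipe_writer.send("init ok")  # the readiness message the launcher waits for

if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=worker, args=(writer,))
    proc.start()
    assert reader.recv() == "init ok"  # block until the subprocess reports ready
    proc.join()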
self.tokenizer_manager.shutdown() - # 确保子进程正确终止 + # Shutdown the MP processes for proc in [self.proc_controller, self.proc_detoken]: if proc.is_alive(): proc.terminate() @@ -159,34 +164,37 @@ def from_engine_args( engine_args: EngineArgs, ) -> "Engine": """Creates an LLM engine from the engine arguments.""" - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - args = parser.parse_args() - args.model_path = engine_args.model_path - server_args = ServerArgs.from_cli_args(args) - server_args.disable_cuda_graph = True engine_config = engine_args.create_engine_config() engine = cls( - server_args, **engine_config.to_dict(), ) return engine + class LLM: - def __init__(self, - model: str, - tokenizer: Optional[str] = None, - tokenizer_mode: str = "auto", - skip_tokenizer_init: bool = False, - trust_remote_code: bool = True, - tensor_parallel_size: int = 1, - dtype: str = "auto", - quantization: Optional[str] = None, - seed: int = 0, - context_length: Optional[int] = None, - **kwargs, - ) -> None: + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = True, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + seed: int = 0, + context_length: Optional[int] = None, + **kwargs, + ) -> None: + engine_arg_fields = {field.name for field in fields(EngineArgs)} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in engine_arg_fields} + + # Warn about any extra kwargs + extra_kwargs = {k: v for k, v in kwargs.items() if k not in engine_arg_fields} + if extra_kwargs: + logger.warning(f"Ignored unexpected kwargs: {extra_kwargs}") + engine_args = EngineArgs( model_path=model, tokenizer_path=tokenizer, @@ -198,14 +206,16 @@ def __init__(self, quantization=quantization, random_seed=seed, context_length=context_length, - **kwargs, + **filtered_kwargs, ) self.llm_engine = Engine.from_engine_args(engine_args) def generate( self, prompts: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union["SamplingParams", List["SamplingParams"]]] = None, + sampling_params: Optional[ + Union["SamplingParams", List["SamplingParams"]] + ] = None, prompt_token_ids: Optional[Union[List[List[int]], List[int]]] = None, ): if prompts is None and prompt_token_ids is None: @@ -221,7 +231,7 @@ def generate( gen_req_input = GenerateReqInput( text=prompts, input_ids=prompt_token_ids, - sampling_params=sampling_params_dicts + sampling_params=sampling_params_dicts, ) try: @@ -240,10 +250,12 @@ def generate( async def _generate_async_helper(self, gen_req_input, request): results = [] - async for response in self.llm_engine.tokenizer_manager.generate_request(gen_req_input, request): + async for response in self.llm_engine.tokenizer_manager.generate_request( + gen_req_input, request + ): if isinstance(response, list): # if gen_req_input is a list input, it is deemed a batched input, the response is alread a list results.extend(response) else: results.append(response) - return results \ No newline at end of file + return results diff --git a/python/sglang/srt/serving/engine_args.py b/python/sglang/srt/serving/engine_args.py index 5fcc607b9a3..6dcf0621213 100644 --- a/python/sglang/srt/serving/engine_args.py +++ b/python/sglang/srt/serving/engine_args.py @@ -16,26 +16,26 @@ """The LLM engine arguments.""" import dataclasses -import os import logging +import os import random -from typing import Optional, Union, List +from typing import List,
Optional, Union -from sglang.srt.server_args import PortArgs -from sglang.srt.constrained import disable_cache from sglang.srt.config import ( EngineConfig, ModelConfig, + ObservabilityConfig, + OptimizationConfig, ParallelConfig, ScheduleConfig, - OptimizationConfig, ) +from sglang.srt.constrained import disable_cache from sglang.srt.utils import ( + allocate_init_ports, assert_pkg_version, enable_show_time_cost, maybe_set_triton_cache_manager, set_ulimit, - allocate_init_ports ) logger = logging.getLogger(__name__) @@ -54,10 +54,14 @@ class EngineArgs: context_length: Optional[int] = None quantization: Optional[str] = None served_model_name: Optional[str] = None - chat_template: Optional[str] = None + random_seed: Optional[int] = None + stream_interval: int = 1 + tokenizer_port: Optional[int] = 0 + detokenizer_port: Optional[int] = 0 + controller_port: Optional[int] = 0 model_override_args: Optional[dict] = None - # Memory and scheduling + # Scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None max_num_reqs: Optional[int] = None @@ -71,41 +75,32 @@ class EngineArgs: tp_size: int = 1 dp_size: int = 1 load_balance_method: str = "round_robin" - - # Distributed args nccl_init_addr: Optional[str] = None nccl_ports: Optional[List[int]] = None additional_ports: Optional[Union[List[int], int]] = None - tokenizer_port: Optional[int] = 0 - detokenizer_port: Optional[int] = 0 - controller_port: Optional[int] = 0 nnodes: int = 1 node_rank: Optional[int] = None - - # Other - file_storage_pth: str = "SGLang_storage" - stream_interval: int = 1 - random_seed: Optional[int] = None - - # Logging - log_level: str = "info" - log_level_http: Optional[str] = None - log_requests: bool = False - show_time_cost: bool = False - # Optimization/debug options + # Optimization disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False disable_radix_cache: bool = False disable_regex_jump_forward: bool = False disable_cuda_graph: bool = False disable_disk_cache: bool = False + enable_mixed_chunk: bool = False enable_torch_compile: bool = False enable_p2p_check: bool = False enable_mla: bool = False attention_reduce_in_fp32: bool = False efficient_weight_load: bool = False + # Observability + log_level: str = "info" + log_level_http: Optional[str] = None + log_requests: bool = False + show_time_cost: bool = False + def __post_init__(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -133,10 +128,14 @@ def __post_init__(self): self.random_seed = random.randint(0, 1 << 30) self._check_args() - self._set_envs_and_config() - def create_engine_config(self,) -> EngineConfig: self._alloc_port_args() + + self._set_envs_and_config() + + def create_engine_config( + self, + ) -> EngineConfig: model_config = ModelConfig( model_path=self.model_path, load_format=self.load_format, @@ -153,7 +152,7 @@ def create_engine_config(self,) -> EngineConfig: tokenizer_port=self.tokenizer_port, detokenizer_port=self.detokenizer_port, controller_port=self.controller_port, - model_override_args=self.model_override_args + model_override_args=self.model_override_args, ) schedule_config = ScheduleConfig( mem_fraction_static=self.mem_fraction_static, @@ -163,7 +162,7 @@ def create_engine_config(self,) -> EngineConfig: chunked_prefill_size=self.chunked_prefill_size, max_prefill_tokens=self.max_prefill_tokens, schedule_policy=self.schedule_policy, - schedule_conservativeness=self.schedule_conservativeness + schedule_conservativeness=self.schedule_conservativeness, ) 
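Together with the ParallelConfig, OptimizationConfig, and ObservabilityConfig assembled just below, this grouping is exactly what Engine.from_engine_args consumes. A minimal usage sketch, assuming a locally available checkpoint path (placeholder shown) and with the optional flashinfer/CUDA-graph paths disabled:

from sglang.srt.serving.engine import Engine
from sglang.srt.serving.engine_args import EngineArgs

engine_args = EngineArgs(
    model_path="/path/to/checkpoint",  # placeholder path
    disable_flashinfer=True,           # skip the flashinfer version check
    disable_cuda_graph=True,
)
engine_config = engine_args.create_engine_config()
# engine_config bundles the model/schedule/parallel/optimization/observability configs;
# Engine(**engine_config.to_dict()) is what Engine.from_engine_args does internally.
engine = Engine(**engine_config.to_dict())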
parallel_config = ParallelConfig( tp_size=self.tp_size, @@ -173,7 +172,7 @@ def create_engine_config(self,) -> EngineConfig: nccl_ports=self.nccl_ports, additional_ports=self.additional_ports, nnodes=self.nnodes, - node_rank=self.node_rank + node_rank=self.node_rank, ) optimization_config = OptimizationConfig( disable_flashinfer=self.disable_flashinfer, @@ -184,29 +183,44 @@ def create_engine_config(self,) -> EngineConfig: disable_disk_cache=self.disable_disk_cache, enable_torch_compile=self.enable_torch_compile, enable_p2p_check=self.enable_p2p_check, + enable_mixed_chunk=self.enable_mixed_chunk, enable_mla=self.enable_mla, attention_reduce_in_fp32=self.attention_reduce_in_fp32, - efficient_weight_load=self.efficient_weight_load + efficient_weight_load=self.efficient_weight_load, + ) + observability_config = ObservabilityConfig( + log_level=self.log_level, + log_level_http=self.log_level_http, + log_requests=self.log_requests, + show_time_cost=self.show_time_cost, ) return EngineConfig( model_config=model_config, schedule_config=schedule_config, parallel_config=parallel_config, optimization_config=optimization_config, + observability_config=observability_config, ) def _alloc_port_args(self): + if isinstance(self.additional_ports, int): + self.additional_ports = [self.additional_ports] + elif self.additional_ports is None: + self.additional_ports = [] + _, ports = allocate_init_ports( 30000, - None, + self.additional_ports, self.dp_size, ) self.tokenizer_port = ports[0] self.controller_port = ports[1] self.detokenizer_port = ports[2] self.nccl_ports = ports[3:] - logger.info(f"Allocated port args: tokenizer_port({self.tokenizer_port}), controller_port({self.controller_port})," - "detokenizer_port({self.detokenizer_port}), nccl_ports({self.nccl_ports})") + logger.info( + f"Allocated port args: tokenizer_port({self.tokenizer_port}), controller_port({self.controller_port})," + f"detokenizer_port({self.detokenizer_port}), nccl_ports({self.nccl_ports})" + ) def _check_args(self): assert ( diff --git a/python/sglang/srt/serving/server.py b/python/sglang/srt/serving/server.py index 5b95e821ad1..0fb88d5cc57 100644 --- a/python/sglang/srt/serving/server.py +++ b/python/sglang/srt/serving/server.py @@ -34,7 +34,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import aiohttp -import psutil import requests import uvicorn import uvloop @@ -57,13 +56,13 @@ v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.serving.engine import Engine, EngineArgs from sglang.srt.serving.server_args import ServerArgs from sglang.srt.utils import ( add_api_key_middleware, allocate_init_ports, kill_child_process, ) -from sglang.srt.serving import Engine from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -74,6 +73,9 @@ app = FastAPI() engine: Engine = None +# for OpenAI files API +file_storage_pth: str + @app.get("/health") async def health() -> Response: @@ -111,7 +113,9 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def stream_results(): try: - async for out in engine.tokenizer_manager.generate_request(obj, request): + async for out in engine.tokenizer_manager.generate_request( + obj, request + ): yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" except ValueError as e: out = {"error": {"message": str(e)}} @@ -125,7 +129,9 @@ async def stream_results(): ) else: try: - ret = await engine.tokenizer_manager.generate_request(obj, request).__anext__() + ret = await 
engine.tokenizer_manager.generate_request( + obj, request + ).__anext__() return ret except ValueError as e: return JSONResponse( @@ -180,9 +186,7 @@ def available_models(): @app.post("/v1/files") async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, engine.tokenizer_manager.server_args.file_storage_pth - ) + return await v1_files_create(file, purpose, file_storage_pth) @app.delete("/v1/files/{file_id}") @@ -215,7 +219,6 @@ async def retrieve_file_content(file_id: str): def launch_server( server_args: ServerArgs, - model_overide_args: Optional[dict] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None, ): """Launch an HTTP server.""" @@ -229,7 +232,13 @@ def launch_server( engine = Engine.from_engine_args(server_args.engine_args) if server_args.chat_template: - load_chat_template_for_openai_api(engine.tokenizer_manager, server_args.chat_template) + load_chat_template_for_openai_api( + engine.tokenizer_manager, server_args.chat_template + ) + + if server_args.file_storage_pth: + global file_storage_pth + file_storage_pth = server_args.file_storage_pth # Add api key authorization if server_args.api_key: @@ -254,6 +263,7 @@ def launch_server( finally: t.join() + def _wait_and_warmup(server_args, pipe_finish_writer): headers = {} url = server_args.url() @@ -334,18 +344,16 @@ class Runtime: def __init__( self, log_level: str = "error", - model_overide_args: Optional[dict] = None, + model_override_args: Optional[dict] = None, *args, **kwargs, ): """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) - - # Pre-allocate ports - self.server_args.port, self.server_args.additional_ports = allocate_init_ports( - self.server_args.port, - self.server_args.additional_ports, - self.server_args.dp_size, + self.server_args = ServerArgs.from_kwargs( + *args, + log_level=log_level, + model_override_args=model_override_args, + **kwargs, ) self.url = self.server_args.url() @@ -357,7 +365,7 @@ def __init__( pipe_reader, pipe_writer = mp.Pipe(duplex=False) proc = mp.Process( target=launch_server, - args=(self.server_args, model_overide_args, pipe_writer), + args=(self.server_args, pipe_writer), ) proc.start() pipe_writer.close() diff --git a/python/sglang/srt/serving/server_args.py b/python/sglang/srt/serving/server_args.py index 1756b9692b6..567aa31d085 100644 --- a/python/sglang/srt/serving/server_args.py +++ b/python/sglang/srt/serving/server_args.py @@ -18,8 +18,8 @@ import argparse import dataclasses import logging -import random -from typing import List, Optional, Union +from dataclasses import fields +from typing import Dict, List, Optional, Union from sglang.srt.serving.engine_args import EngineArgs @@ -29,7 +29,7 @@ @dataclasses.dataclass class ServerArgs: # The core engine args - engine_args: EngineArgs + engine_args: EngineArgs # = field(default_factory=EngineArgs) # # The server specifc args @@ -37,25 +37,54 @@ class ServerArgs: # Connection host: str = "127.0.0.1" port: int = 30000 - additional_ports: Optional[Union[List[int], int]] = None - # Model and tokenizer + # OpenAI API chat_template: Optional[str] = None file_storage_pth: str = "SGLang_storage" - # Log - log_level: str = "info" - log_level_http: Optional[str] = None - log_requests: bool = False, - # Authentication api_key: Optional[str] = None - def __post_init__(self): - if isinstance(self.additional_ports, int): - self.additional_ports = [self.additional_ports] - elif 
self.additional_ports is None: - self.additional_ports = [] + def __post_init__(self): ... + + def __getattr__(self, item): + # Forward attribute access to engine_args if not found in ServerArgs. + # For attribute in server_args, it will be found in ServerArgs's __dict__ + # and no entry into this function. + if hasattr(self.engine_args, item): + return getattr(self.engine_args, item) + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + + def __setattr__(self, key, value): + # If the attribute exists in ServerArgs, set it directly + if key in {f.name for f in fields(ServerArgs)}: + super().__setattr__(key, value) + # If the attribute exists in EngineArgs, forward it to engine_args + elif hasattr(self.engine_args, key): + setattr(self.engine_args, key, value) + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{key}'" + ) + + @classmethod + def from_kwargs( + cls, + *args, + **kwargs: Dict[str, any], + ) -> "ServerArgs": + """Creates a ServerArgs instance by separating EngineArgs and ServerArgs parameters.""" + engine_args_fields = {field.name for field in fields(EngineArgs)} + server_args_fields = {field.name for field in fields(cls)} - {"engine_args"} + + engine_args_dict = {k: v for k, v in kwargs.items() if k in engine_args_fields} + server_args_dict = {k: v for k, v in kwargs.items() if k in server_args_fields} + + engine_args = EngineArgs(*args, **engine_args_dict) + + return cls(engine_args=engine_args, **server_args_dict) @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -164,7 +193,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--chat-template", type=str, - default=EngineArgs.chat_template, + default=ServerArgs.chat_template, help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.", ) parser.add_argument( @@ -238,13 +267,13 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--log-level", type=str, - default=ServerArgs.log_level, + default=EngineArgs.log_level, help="The logging level of all loggers.", ) parser.add_argument( "--log-level-http", type=str, - default=ServerArgs.log_level_http, + default=EngineArgs.log_level_http, help="The logging level of HTTP server. If not set, reuse --log-level by default.", ) parser.add_argument( @@ -331,6 +360,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", ) + parser.add_argument( + "--enable-mixed-chunk", + action="store_true", + help="Enabling mixing prefill and decode in a chunked batch.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", @@ -365,20 +399,20 @@ def from_cli_args(cls, args: argparse.Namespace): # Init EngineArgs engine_args_fields = {field.name for field in dataclasses.fields(EngineArgs)} - engine_args_dict = {key: getattr(args, key) for key in engine_args_fields if hasattr(args, key)} + engine_args_dict = { + key: getattr(args, key) for key in engine_args_fields if hasattr(args, key) + } engine_args = EngineArgs(**engine_args_dict) - + # Init ServerArgs with the remaining fields... 
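Both LLM.__init__ earlier in this diff and from_kwargs above route keyword arguments by inspecting dataclass fields, forwarding the known names and ignoring the rest. The idiom in isolation, as a runnable toy (ToyArgs is a stand-in, not a real SGLang class):

import dataclasses

@dataclasses.dataclass
class ToyArgs:  # stand-in for EngineArgs
    model_path: str = ""
    tp_size: int = 1

kwargs = {"model_path": "dummy-model", "tp_size": 2, "not_an_option": True}
known = {f.name for f in dataclasses.fields(ToyArgs)}
routed = {k: v for k, v in kwargs.items() if k in known}       # forwarded to the dataclass
ignored = {k: v for k, v in kwargs.items() if k not in known}  # reported, then dropped
print(ToyArgs(**routed), ignored)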
- server_args_fields = {field.name for field in dataclasses.fields(cls)} - {'engine_args'} - server_args_dict = {key: getattr(args, key) for key in server_args_fields if hasattr(args, key)} - + server_args_fields = {field.name for field in dataclasses.fields(cls)} - { + "engine_args" + } + server_args_dict = { + key: getattr(args, key) for key in server_args_fields if hasattr(args, key) + } + return cls(engine_args=engine_args, **server_args_dict) - - def __getattr__(self, item): - # Forward attribute access to engine_args if not found in ServerArgs - if hasattr(self.engine_args, item): - return getattr(self.engine_args, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") - + def url(self): - return f"http://{self.host}:{self.port}" \ No newline at end of file + return f"http://{self.host}:{self.port}" diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 9761c851a52..3a95f4828a1 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -209,14 +209,14 @@ def get_int_token_logit_bias(tokenizer, vocab_size): def is_multimodal_model(model): - from sglang.srt.model_config import ModelConfig + from sglang.srt.config import ModelConfig if isinstance(model, str): model = model.lower() return "llava" in model or "yi-vl" in model or "llava-next" in model if isinstance(model, ModelConfig): - model_path = model.path.lower() + model_path = model.model_path.lower() return ( "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path ) @@ -412,15 +412,21 @@ def monkey_patch_vllm_dummy_weight_loader(): Monkey patch the dummy weight loader in vllm to call process_weights_after_loading. """ + from vllm.model_executor.model_loader.loader import CacheConfig as VllmCacheConfig + from vllm.model_executor.model_loader.loader import DeviceConfig as VllmDeviceConfig + from vllm.model_executor.model_loader.loader import DummyModelLoader + from vllm.model_executor.model_loader.loader import LoRAConfig as VllmLoRAConfig + from vllm.model_executor.model_loader.loader import ModelConfig as VllmModelConfig + from vllm.model_executor.model_loader.loader import ( + MultiModalConfig as VllmMultiModalConfig, + ) + from vllm.model_executor.model_loader.loader import ( + ParallelConfig as VllmParallelConfig, + ) + from vllm.model_executor.model_loader.loader import ( + SchedulerConfig as VllmSchedulerConfig, + ) from vllm.model_executor.model_loader.loader import ( - CacheConfig, - DeviceConfig, - DummyModelLoader, - LoRAConfig, - ModelConfig, - MultiModalConfig, - ParallelConfig, - SchedulerConfig, _initialize_model, initialize_dummy_weights, nn, @@ -430,13 +436,13 @@ def monkey_patch_vllm_dummy_weight_loader(): def load_model( self, *, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, + model_config: VllmModelConfig, + device_config: VllmDeviceConfig, + lora_config: Optional[VllmLoRAConfig], + multimodal_config: Optional[VllmMultiModalConfig], + parallel_config: VllmParallelConfig, + scheduler_config: VllmSchedulerConfig, + cache_config: VllmCacheConfig, ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e325ecb710e..adc690dc4d6 100644 --- a/python/sglang/test/runners.py +++ 
b/python/sglang/test/runners.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer -from sglang.srt.server import Runtime +from sglang.srt.serving.server import Runtime from sglang.srt.utils import is_generation_model DEFAULT_PROMPTS = [ diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index bbcd5122769..ec0c67c9c14 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.bench_serving import run_benchmark -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -66,9 +66,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -78,8 +78,8 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -88,8 +88,8 @@ def test_default_without_radix_cache(self): def test_default_without_chunked_prefill(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, chunked_prefill_size=-1, ) diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index 261ac6ec52f..cfdb6632058 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.bench_serving import run_benchmark -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -64,9 +64,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -76,8 +76,8 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if 
os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -86,8 +86,8 @@ def test_default_without_radix_cache(self): def test_default_without_chunked_prefill(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, chunked_prefill_size=-1, )