Commit 0d67d1c

rebase

2 parents 7336d41 + 29e30f5

File tree

11 files changed: +169 -54 lines changed

11 files changed

+169
-54
lines changed

components/backends/sglang/README.md

Lines changed: 26 additions & 0 deletions
@@ -23,6 +23,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
 
 ## Table of Contents
 - [Feature Support Matrix](#feature-support-matrix)
+- [Dynamo SGLang Integration](#dynamo-sglang-integration)
 - [Quick Start](#quick-start)
 - [Single Node Examples](#run-single-node-examples)
 - [Multi-Node and Advanced Examples](#advanced-examples)
@@ -50,6 +51,31 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
 | **GB200 Support** || |
 
 
+## Dynamo SGLang Integration
+
+Dynamo SGLang integrates SGLang engines into Dynamo's distributed runtime, enabling advanced features like disaggregated serving, KV-aware routing, and request migration while maintaining full compatibility with SGLang's engine arguments.
+
+### Argument Handling
+
+Dynamo SGLang uses SGLang's native argument parser, so **most SGLang engine arguments work identically**. You can pass any SGLang argument (like `--model-path`, `--tp`, `--trust-remote-code`) directly to `dynamo.sglang`.
+
+#### Dynamo-Specific Arguments
+
+| Argument | Description | Default | SGLang Equivalent |
+|----------|-------------|---------|-------------------|
+| `--endpoint` | Dynamo endpoint in `dyn://namespace.component.endpoint` format | Auto-generated based on mode | N/A |
+| `--migration-limit` | Max times a request can migrate between workers | `0` (disabled) | N/A |
+| `--dyn-tool-call-parser` | Tool call parser for structured outputs (takes precedence over `--tool-call-parser`) | `None` | `--tool-call-parser` |
+| `--dyn-reasoning-parser` | Reasoning parser for CoT models (takes precedence over `--reasoning-parser`) | `None` | `--reasoning-parser` |
+| `--use-sglang-tokenizer` | Use SGLang's tokenizer instead of Dynamo's | `False` | N/A |
+
+#### Tokenizer Behavior
+
+- **Default (`--use-sglang-tokenizer` not set)**: Dynamo handles tokenization and passes `input_ids` to SGLang
+- **With `--use-sglang-tokenizer`**: SGLang handles tokenization, Dynamo passes raw prompts
+
+> **Note**: When using `--use-sglang-tokenizer`, only `v1/chat/completions` endpoints are available through Dynamo's frontend.
+
 ## SGLang Quick Start
 
 Below we provide a guide that lets you run all of our common deployment patterns on a single node.
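A quick illustration of the argument handling described in the new README section: SGLang-native flags pass through unchanged while the Dynamo-specific flags from the table are added alongside them. This is a minimal sketch; the model path, migration limit, and parser name are illustrative placeholders, not values taken from this commit.

```bash
# Sketch only: --model-path, --tp, and --trust-remote-code go straight to
# SGLang's own parser; --migration-limit and --dyn-reasoning-parser are the
# Dynamo-specific additions documented in the table above.
python3 -m dynamo.sglang \
  --model-path Qwen/Qwen3-0.6B \
  --tp 1 \
  --trust-remote-code \
  --migration-limit 3 \
  --dyn-reasoning-parser deepseek-r1
```

Without `--use-sglang-tokenizer`, this runs in the default mode where Dynamo tokenizes the request and passes `input_ids` to SGLang.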

components/backends/sglang/launch/agg.sh

Lines changed: 1 addition & 2 deletions
@@ -24,5 +24,4 @@ python3 -m dynamo.sglang \
   --served-model-name Qwen/Qwen3-0.6B \
   --page-size 16 \
   --tp 1 \
-  --trust-remote-code \
-  --skip-tokenizer-init
+  --trust-remote-code

components/backends/sglang/launch/agg_router.sh

Lines changed: 0 additions & 2 deletions
@@ -25,7 +25,6 @@ python3 -m dynamo.sglang \
   --page-size 16 \
   --tp 1 \
   --trust-remote-code \
-  --skip-tokenizer-init \
   --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' &
 WORKER_PID=$!
 
@@ -35,5 +34,4 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
   --page-size 16 \
   --tp 1 \
   --trust-remote-code \
-  --skip-tokenizer-init \
   --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}'

components/backends/sglang/launch/disagg.sh

Lines changed: 4 additions & 2 deletions
@@ -25,8 +25,9 @@ python3 -m dynamo.sglang \
   --page-size 16 \
   --tp 1 \
   --trust-remote-code \
-  --skip-tokenizer-init \
   --disaggregation-mode prefill \
+  --disaggregation-bootstrap-port 12345 \
+  --host 0.0.0.0 \
   --disaggregation-transfer-backend nixl &
 PREFILL_PID=$!
 
@@ -37,6 +38,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
   --page-size 16 \
   --tp 1 \
   --trust-remote-code \
-  --skip-tokenizer-init \
   --disaggregation-mode decode \
+  --disaggregation-bootstrap-port 12345 \
+  --host 0.0.0.0 \
   --disaggregation-transfer-backend nixl

components/backends/sglang/launch/disagg_dp_attn.sh

Lines changed: 0 additions & 2 deletions
@@ -30,7 +30,6 @@ python3 -m dynamo.sglang \
   --dp-size 2 \
   --enable-dp-attention \
   --trust-remote-code \
-  --skip-tokenizer-init \
   --disaggregation-mode prefill \
   --disaggregation-transfer-backend nixl \
   --expert-distribution-recorder-mode stat \
@@ -45,7 +44,6 @@ CUDA_VISIBLE_DEVICES=2,3 python3 -m dynamo.sglang \
   --dp-size 2 \
   --enable-dp-attention \
   --trust-remote-code \
-  --skip-tokenizer-init \
   --disaggregation-mode decode \
   --disaggregation-transfer-backend nixl \
   --expert-distribution-recorder-mode stat \
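The launch scripts above simply drop `--skip-tokenizer-init`, since `parse_args` (changed in `args.py` below) now sets it automatically based on `--use-sglang-tokenizer`. As a minimal sketch of the opt-in path, a worker that wants SGLang to own tokenization would be launched roughly like this; the model name and the other flag values are illustrative, only `--use-sglang-tokenizer` is the relevant addition.

```bash
# Sketch only: with --use-sglang-tokenizer the Dynamo frontend forwards raw
# prompts, the worker registers as text-input/Chat-only, and only
# v1/chat/completions is served.
python3 -m dynamo.sglang \
  --model-path Qwen/Qwen3-0.6B \
  --served-model-name Qwen/Qwen3-0.6B \
  --page-size 16 \
  --tp 1 \
  --trust-remote-code \
  --use-sglang-tokenizer
```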

components/backends/sglang/src/dynamo/sglang/args.py

Lines changed: 30 additions & 11 deletions
@@ -49,6 +49,12 @@
         "choices": get_reasoning_parser_names(),
         "help": "Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
     },
+    "use-sglang-tokenizer": {
+        "flags": ["--use-sglang-tokenizer"],
+        "action": "store_true",
+        "default": False,
+        "help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend",
+    },
 }
 
 
@@ -63,6 +69,9 @@ class DynamoArgs:
     tool_call_parser: Optional[str] = None
     reasoning_parser: Optional[str] = None
 
+    # preprocessing options
+    use_sglang_tokenizer: bool = False
+
 
 class DisaggregationMode(Enum):
     AGGREGATED = "agg"
@@ -127,13 +136,18 @@ def parse_args(args: list[str]) -> Config:
 
     # Dynamo args
     for info in DYNAMO_ARGS.values():
-        parser.add_argument(
-            *info["flags"],
-            type=info["type"],
-            default=info["default"] if "default" in info else None,
-            help=info["help"],
-            choices=info.get("choices", None),
-        )
+        kwargs = {
+            "default": info["default"] if "default" in info else None,
+            "help": info["help"],
+        }
+        if "type" in info:
+            kwargs["type"] = info["type"]
+        if "choices" in info:
+            kwargs["choices"] = info["choices"]
+        if "action" in info:
+            kwargs["action"] = info["action"]
+
+        parser.add_argument(*info["flags"], **kwargs)
 
     # SGLang args
     bootstrap_port = _reserve_disaggregation_bootstrap_port()
@@ -191,15 +205,20 @@ def parse_args(args: list[str]) -> Config:
         migration_limit=parsed_args.migration_limit,
         tool_call_parser=tool_call_parser,
         reasoning_parser=reasoning_parser,
+        use_sglang_tokenizer=parsed_args.use_sglang_tokenizer,
     )
     logging.debug(f"Dynamo args: {dynamo_args}")
 
     server_args = ServerArgs.from_cli_args(parsed_args)
 
-    if not server_args.skip_tokenizer_init:
-        logging.warning(
-            "When using the dynamo frontend (python3 -m dynamo.frontend), we perform tokenization and detokenization "
-            "in the frontend. Automatically setting --skip-tokenizer-init to True."
+    if parsed_args.use_sglang_tokenizer:
+        logging.info(
+            "Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
+        )
+        server_args.skip_tokenizer_init = False
+    else:
+        logging.info(
+            "Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
         )
         server_args.skip_tokenizer_init = True
 

components/backends/sglang/src/dynamo/sglang/protocol.py

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,10 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from pydantic import BaseModel, Field
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 
 TokenIdType = int
 
@@ -35,7 +36,6 @@ class SamplingOptions(BaseModel):
 
 class PreprocessedRequest(BaseModel):
     token_ids: List[TokenIdType]
-    batch_token_ids: Optional[List[List[TokenIdType]]] = None
     stop_conditions: StopConditions
     sampling_options: SamplingOptions
     eos_token_ids: List[TokenIdType] = Field(default_factory=list)
@@ -44,6 +44,6 @@ class PreprocessedRequest(BaseModel):
 
 
 class DisaggPreprocessedRequest(BaseModel):
-    request: PreprocessedRequest
+    request: Union[PreprocessedRequest, ChatCompletionRequest]
     sampling_params: dict
     data_parallel_rank: Optional[int] = None

components/backends/sglang/src/dynamo/sglang/register.py

Lines changed: 10 additions & 2 deletions
@@ -24,10 +24,18 @@ async def register_llm_with_runtime_config(
         bool: True if registration succeeded, False if it failed
     """
     runtime_config = await _get_runtime_config(engine, dynamo_args)
+    input_type = ModelInput.Tokens
+    output_type = ModelType.Chat | ModelType.Completions
+    if not server_args.skip_tokenizer_init:
+        logging.warning(
+            "The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
+        )
+        input_type = ModelInput.Text
+        output_type = ModelType.Chat
     try:
         await register_llm(
-            ModelInput.Tokens,
-            ModelType.Chat | ModelType.Completions,
+            input_type,
+            output_type,
             endpoint,
             server_args.model_path,
             server_args.served_model_name,

components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py

Lines changed: 79 additions & 29 deletions
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+import time
 
 import sglang as sgl
 
@@ -41,20 +42,33 @@ def cleanup(self):
         super().cleanup()
 
     def _build_sampling_params(self, request: dict) -> dict:
-        sampling_params = {}
-        if request["sampling_options"]["temperature"]:
-            sampling_params["temperature"] = request["sampling_options"]["temperature"]
-        if request["sampling_options"]["top_p"]:
-            sampling_params["top_p"] = request["sampling_options"]["top_p"]
-        if request["sampling_options"]["top_k"]:
-            sampling_params["top_k"] = request["sampling_options"]["top_k"]
-        sampling_params["max_new_tokens"] = request["stop_conditions"]["max_tokens"]
-        if request["stop_conditions"]["ignore_eos"]:
-            sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"]
-        return sampling_params
+        """Build sampling params depending on request from frontend"""
+        if self.skip_tokenizer_init:
+            # Token-based request format
+            sampling_opts = request.get("sampling_options", {})
+            stop_conditions = request.get("stop_conditions", {})
+
+            param_mapping = {
+                "temperature": sampling_opts.get("temperature"),
+                "top_p": sampling_opts.get("top_p"),
+                "top_k": sampling_opts.get("top_k"),
+                "max_new_tokens": stop_conditions.get("max_tokens"),
+                "ignore_eos": stop_conditions.get("ignore_eos"),
+            }
+        else:
+            # OpenAI request format
+            param_mapping = {
+                "temperature": request.get("temperature"),
+                "top_p": request.get("top_p"),
+                "top_k": request.get("top_k"),
+                "max_new_tokens": request.get("max_tokens"),
+            }
+
+        return {k: v for k, v in param_mapping.items() if v is not None}
 
     async def generate(self, request: dict):
         sampling_params = self._build_sampling_params(request)
+        input_param = self._get_input_param(request)
 
         if self.serving_mode == DisaggregationMode.DECODE:
             # request the bootstrap info from the target prefill worker
@@ -74,41 +88,77 @@ async def generate(self, request: dict):
                 raise RuntimeError("No bootstrap info received from prefill worker")
 
             decode = await self.engine.async_generate(
-                input_ids=request["token_ids"],
+                **input_param,
                 sampling_params=sampling_params,
                 stream=True,
                 bootstrap_host=bootstrap_info["bootstrap_host"],
                 bootstrap_port=bootstrap_info["bootstrap_port"],
                 bootstrap_room=bootstrap_info["bootstrap_room"],
             )
 
-            async for out in self._process_stream(decode):
-                yield out
+            if self.skip_tokenizer_init:
+                async for out in self._process_token_stream(decode):
+                    yield out
+            else:
+                async for out in self._process_text_stream(decode):
+                    yield out
         else:
             agg = await self.engine.async_generate(
-                input_ids=request["token_ids"],
+                **input_param,
                 sampling_params=sampling_params,
                 stream=True,
             )
-            async for out in self._process_stream(agg):
-                yield out
-
-    async def _process_stream(self, stream_source):
+            if self.skip_tokenizer_init:
+                async for out in self._process_token_stream(agg):
+                    yield out
+            else:
+                async for out in self._process_text_stream(agg):
+                    yield out
+
+    async def _process_token_stream(self, stream_source):
         num_output_tokens_so_far = 0
 
         async for res in stream_source:
-            try:
-                next_total_toks = len(res["output_ids"])
-            except KeyError:
-                raise ValueError(
-                    f"Missing 'output_ids' in response. This often happens when using skip_tokenizer_init=False. "
-                    f"If you're using ModelType.CHAT or custom model configurations, you may need to modify "
-                    f"the tokenization/detokenization logic in your handler. Response keys: {list(res.keys())}"
-                )
-            out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
-            num_output_tokens_so_far = next_total_toks
             finish_reason = res["meta_info"]["finish_reason"]
             if finish_reason:
                 out = {"token_ids": [], "finish_reason": finish_reason["type"]}
+            else:
+                try:
+                    next_total_toks = len(res["output_ids"])
+                except KeyError:
+                    raise ValueError(
+                        f"Missing 'output_ids' in response. Response keys: {list(res.keys())}"
+                    )
+                out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
+                num_output_tokens_so_far = next_total_toks
 
             yield out
+
+    async def _process_text_stream(self, stream_source):
+        """Process stream for text input mode"""
+        count = 0
+
+        async for res in stream_source:
+            index = res.get("index", 0)
+            text = res.get("text", "")
+
+            finish_reason = res["meta_info"]["finish_reason"]
+            finish_reason_type = finish_reason["type"] if finish_reason else None
+            next_count = len(text)
+            delta = text[count:]
+
+            choice_data = {
+                "index": index,
+                "delta": {"role": "assistant", "content": delta},
+                "finish_reason": finish_reason_type,
+            }
+
+            response = {
+                "id": res["meta_info"]["id"],
+                "created": int(time.time()),
+                "choices": [choice_data],
+                "model": self.config.server_args.served_model_name,
+                "object": "chat.completion.chunk",
+            }
+            yield response
+            count = next_count

components/backends/sglang/src/dynamo/sglang/request_handlers/handler_base.py

Lines changed: 13 additions & 0 deletions
@@ -27,10 +27,23 @@ def __init__(
         self.kv_publisher = kv_publisher
         self.prefill_client = prefill_client
         self.serving_mode = config.serving_mode
+        self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
 
     @abstractmethod
     async def generate(self, request: str):
         pass
 
     def cleanup(self):
         pass
+
+    def _get_input_param(self, request: dict) -> dict:
+        """Get the appropriate input parameter for SGLang"""
+        if self.skip_tokenizer_init:
+            return {"input_ids": request["token_ids"]}
+        else:
+            # use sglang's chat templating itself but leave tokenization to the
+            # internal engine's TokenizerManager
+            prompt = self.engine.tokenizer_manager.tokenizer.apply_chat_template(
+                request["messages"], tokenize=False, add_generation_prompt=True
+            )
+            return {"prompt": prompt}
