Skip to content

Commit a6884b2

Browse files
committed
fix: add missing max_num_seqs metrics for SGLang and TensorRT-LLM backends
SGLang: get max_num_seqs from server_args, since SGLang separates config from runtime stats. TensorRT-LLM: populate max_num_seqs and max_num_batched_tokens from config for metrics consistency.
1 parent a6a457a commit a6884b2

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

components/backends/sglang/src/dynamo/sglang/register.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ async def register_llm_with_runtime_config(
2323
Returns:
2424
bool: True if registration succeeded, False if it failed
2525
"""
26-
runtime_config = await _get_runtime_config(engine, dynamo_args)
26+
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
2727
input_type = ModelInput.Tokens
2828
output_type = ModelType.Chat | ModelType.Completions
2929
if not server_args.skip_tokenizer_init:
@@ -51,13 +51,25 @@ async def register_llm_with_runtime_config(
5151

5252

5353
async def _get_runtime_config(
54-
engine: sgl.Engine, dynamo_args: DynamoArgs
54+
engine: sgl.Engine, server_args: ServerArgs, dynamo_args: DynamoArgs
5555
) -> Optional[ModelRuntimeConfig]:
5656
"""Get runtime config from SGLang engine"""
5757
runtime_config = ModelRuntimeConfig()
5858
# set reasoning parser and tool call parser
5959
runtime_config.reasoning_parser = dynamo_args.reasoning_parser
6060
runtime_config.tool_call_parser = dynamo_args.tool_call_parser
61+
62+
# In SGLang, these are server_args, not scheduler_info (unlike vLLM)
63+
# Note: If --max-running-requests is not specified, SGLang falls back to an internal,
64+
# undocumented default value. The value here will be None if not explicitly set by the user.
65+
max_running_requests = getattr(server_args, "max_running_requests", None)
66+
if max_running_requests:
67+
runtime_config.max_num_seqs = max_running_requests
68+
69+
max_prefill_tokens = getattr(server_args, "max_prefill_tokens", None)
70+
if max_prefill_tokens:
71+
runtime_config.max_num_batched_tokens = max_prefill_tokens
72+
6173
try:
6274
# Try to check if the engine has a scheduler attribute with the computed values
6375
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
@@ -77,7 +89,10 @@ async def _get_runtime_config(
7789
f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
7890
)
7991

80-
# Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
92+
# Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info.
93+
# SGLang separates configuration (server_args) from runtime stats (scheduler_info).
94+
# In contrast, vLLM exposes both config and runtime values through engine config.
95+
# These are config parameters, so they must be retrieved from server_args only.
8196

8297
return runtime_config
8398

components/backends/trtllm/src/dynamo/trtllm/main.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,29 @@ async def init(runtime: DistributedRuntime, config: Config):
281281
# TODO: fix this once we have a better way to get total_kv_blocks
282282
runtime_config = ModelRuntimeConfig()
283283

284+
# Set values from config that are available immediately
285+
# Note: We populate max_num_seqs and max_num_batched_tokens from config
286+
# to ensure Prometheus metrics are available even without engine stats
287+
288+
# Naming clarification:
289+
# - In vLLM: max_num_seqs = maximum concurrent requests (an unusual name that vLLM keeps for historical reasons)
290+
# - In TensorRT-LLM: max_batch_size = maximum concurrent requests (clearer name)
291+
# Both parameters control the same thing: how many requests can be processed simultaneously
292+
runtime_config.max_num_seqs = config.max_batch_size
293+
runtime_config.max_num_batched_tokens = config.max_num_tokens
284294
runtime_config.reasoning_parser = config.reasoning_parser
285295
runtime_config.tool_call_parser = config.tool_call_parser
286296

297+
logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")
298+
logging.info(
299+
f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}"
300+
)
301+
302+
# The get_engine_runtime_config function exists but is not called here due to:
303+
# 1. get_stats_async requires active requests to work properly
304+
# 2. We need runtime config during registration, before any requests are made
305+
# 3. total_kv_blocks would ideally come from engine stats but is not critical for basic operation
306+
287307
# publisher will be set later if publishing is enabled.
288308
handler_config = RequestHandlerConfig(
289309
component=component,

0 commit comments

Comments
 (0)