|
16 | 16 | from sglang.srt.server_args import ServerArgs |
17 | 17 | from sglang.srt.utils import get_ip, get_zmq_socket |
18 | 18 |
|
| 19 | +from dynamo._core import Endpoint |
19 | 20 | from dynamo.llm import ( |
20 | 21 | ForwardPassMetrics, |
21 | 22 | KvStats, |
| 23 | + ModelRuntimeConfig, |
22 | 24 | ModelType, |
23 | 25 | WorkerMetricsPublisher, |
24 | 26 | WorkerStats, |
@@ -334,13 +336,8 @@ async def init( |
334 | 336 | await component.create_service() |
335 | 337 |
|
336 | 338 | endpoint = component.endpoint("generate") |
337 | | - await register_llm( |
338 | | - ModelType.Backend, |
339 | | - endpoint, |
340 | | - server_args.model_path, |
341 | | - server_args.served_model_name, |
342 | | - kv_cache_block_size=server_args.page_size, |
343 | | - migration_limit=migration_limit, |
| 339 | + await register_llm_with_runtime_config( |
| 340 | + engine, endpoint, server_args, migration_limit |
344 | 341 | ) |
345 | 342 |
|
346 | 343 | if server_args.disaggregation_mode != "null": |
@@ -372,12 +369,75 @@ async def init( |
372 | 369 | _ = ZmqKvEventPublisher(component=component, config=zmq_config) |
373 | 370 |
|
374 | 371 | tasks = [endpoint.serve_endpoint(handler.generate)] |
375 | | - |
376 | 372 | tasks.extend(setup_native_endpoints(server_args, component, handler)) |
377 | 373 |
|
378 | 374 | await asyncio.gather(*tasks) |
379 | 375 |
|
380 | 376 |
|
| 377 | +async def register_llm_with_runtime_config( |
| 378 | + engine: sgl.Engine, |
| 379 | + endpoint: Endpoint, |
| 380 | + server_args: ServerArgs, |
| 381 | + migration_limit: int, |
| 382 | +): |
| 383 | + """Register LLM with runtime config""" |
| 384 | + runtime_config = await _get_runtime_config(engine) |
| 385 | + try: |
| 386 | + await register_llm( |
| 387 | + ModelType.Backend, |
| 388 | + endpoint, |
| 389 | + server_args.model_path, |
| 390 | + server_args.served_model_name, |
| 391 | + kv_cache_block_size=server_args.page_size, |
| 392 | + migration_limit=migration_limit, |
| 393 | + runtime_config=runtime_config, |
| 394 | + ) |
| 395 | + except Exception as e: |
| 396 | + logging.error(f"Failed to register with runtime config: {e}") |
| 397 | + return None |
| 398 | + |
| 399 | + |
| 400 | +async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]: |
| 401 | + """Get runtime config from SGLang engine""" |
| 402 | + try: |
| 403 | + # Try to check if the engine has a scheduler attribute with the computed values |
| 404 | + if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: |
| 405 | + runtime_config = ModelRuntimeConfig() |
| 406 | + |
| 407 | + # Get max_total_num_tokens from scheduler_info |
| 408 | + if "max_total_num_tokens" in engine.scheduler_info: |
| 409 | + max_total_tokens = engine.scheduler_info["max_total_num_tokens"] |
| 410 | + if max_total_tokens and hasattr( |
| 411 | + engine.tokenizer_manager, "server_args" |
| 412 | + ): |
| 413 | + page_size = engine.tokenizer_manager.server_args.page_size |
| 414 | + if page_size: |
| 415 | + runtime_config.total_kv_blocks = ( |
| 416 | + max_total_tokens + page_size - 1 |
| 417 | + ) // page_size |
| 418 | + logging.info( |
| 419 | + f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} " |
| 420 | + f"(max_total_tokens={max_total_tokens}, page_size={page_size})" |
| 421 | + ) |
| 422 | + |
| 423 | + # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info |
| 424 | + # TODO: figure out where they are |
| 425 | + |
| 426 | + return runtime_config |
| 427 | + |
| 428 | + # If scheduler approach doesn't work, log and return None to indicate we'll skip runtime config |
| 429 | + logging.warning( |
| 430 | + "Could not access runtime config from SGLang engine. " |
| 431 | + "The engine may compute these values internally after initialization. " |
| 432 | + "Proceeding without runtime config - SGLang will use its internal defaults." |
| 433 | + ) |
| 434 | + return None |
| 435 | + |
| 436 | + except Exception as e: |
| 437 | + logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.") |
| 438 | + return None |
| 439 | + |
| 440 | + |
381 | 441 | def main(): |
382 | 442 | uvloop.install() |
383 | 443 | asyncio.run(worker()) |
|
0 commit comments