Skip to content

Commit d0a6363

Browse files
feat: add RuntimeConfig to ModelEntry (#2311)
Co-authored-by: Yan Ru Pei <yanrpei@gmail.com>
1 parent b74b887 commit d0a6363

File tree

16 files changed

+469
-35
lines changed

16 files changed

+469
-35
lines changed

components/backends/sglang/src/dynamo/sglang/worker/main.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
from sglang.srt.server_args import ServerArgs
1717
from sglang.srt.utils import get_ip, get_zmq_socket
1818

19+
from dynamo._core import Endpoint
1920
from dynamo.llm import (
2021
ForwardPassMetrics,
2122
KvStats,
23+
ModelRuntimeConfig,
2224
ModelType,
2325
WorkerMetricsPublisher,
2426
WorkerStats,
@@ -334,13 +336,8 @@ async def init(
334336
await component.create_service()
335337

336338
endpoint = component.endpoint("generate")
337-
await register_llm(
338-
ModelType.Backend,
339-
endpoint,
340-
server_args.model_path,
341-
server_args.served_model_name,
342-
kv_cache_block_size=server_args.page_size,
343-
migration_limit=migration_limit,
339+
await register_llm_with_runtime_config(
340+
engine, endpoint, server_args, migration_limit
344341
)
345342

346343
if server_args.disaggregation_mode != "null":
@@ -372,12 +369,75 @@ async def init(
372369
_ = ZmqKvEventPublisher(component=component, config=zmq_config)
373370

374371
tasks = [endpoint.serve_endpoint(handler.generate)]
375-
376372
tasks.extend(setup_native_endpoints(server_args, component, handler))
377373

378374
await asyncio.gather(*tasks)
379375

380376

377+
async def register_llm_with_runtime_config(
    engine: sgl.Engine,
    endpoint: Endpoint,
    server_args: ServerArgs,
    migration_limit: int,
):
    """Register the SGLang backend model on ``endpoint`` with its runtime config.

    The runtime config is collected on a best-effort basis via
    ``_get_runtime_config`` (it may be ``None``) and forwarded to
    ``register_llm`` together with the model path, served model name,
    page size (as the KV cache block size) and migration limit.

    Args:
        engine: SGLang engine the runtime config is read from.
        endpoint: Endpoint the model is registered on.
        server_args: Server arguments supplying ``model_path``,
            ``served_model_name`` and ``page_size``.
        migration_limit: Forwarded unchanged to ``register_llm``.

    Raises:
        Exception: Any error raised by ``register_llm`` is logged and then
            re-raised, so a worker never keeps running while unregistered.
    """
    runtime_config = await _get_runtime_config(engine)
    try:
        await register_llm(
            ModelType.Backend,
            endpoint,
            server_args.model_path,
            server_args.served_model_name,
            kv_cache_block_size=server_args.page_size,
            migration_limit=migration_limit,
            runtime_config=runtime_config,
        )
    except Exception as e:
        logging.error(f"Failed to register with runtime config: {e}")
        # Registration is not optional: swallowing the error here would leave
        # the caller serving an endpoint that was never registered.
        raise
398+
399+
400+
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
    """Build a ModelRuntimeConfig from the SGLang engine's ``scheduler_info``.

    Returns ``None`` (after logging a warning) when the engine exposes no
    ``scheduler_info`` or when any error occurs while reading it. Otherwise
    returns a config whose ``total_kv_blocks`` is filled in only if both
    ``max_total_num_tokens`` and the tokenizer manager's ``page_size`` are
    present and truthy.
    """
    try:
        info = getattr(engine, "scheduler_info", None)
        if info is None:
            # Nothing to read from: tell the caller to skip the runtime config.
            logging.warning(
                "Could not access runtime config from SGLang engine. "
                "The engine may compute these values internally after initialization. "
                "Proceeding without runtime config - SGLang will use its internal defaults."
            )
            return None

        runtime_config = ModelRuntimeConfig()

        # total_kv_blocks = ceil(max_total_num_tokens / page_size)
        if "max_total_num_tokens" in info:
            token_budget = info["max_total_num_tokens"]
            if token_budget and hasattr(engine.tokenizer_manager, "server_args"):
                block_tokens = engine.tokenizer_manager.server_args.page_size
                if block_tokens:
                    runtime_config.total_kv_blocks = (
                        token_budget + block_tokens - 1
                    ) // block_tokens
                    logging.info(
                        f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
                        f"(max_total_tokens={token_budget}, page_size={block_tokens})"
                    )

        # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
        # TODO: figure out where they are
        return runtime_config

    except Exception as e:
        logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
        return None
439+
440+
381441
def main():
382442
uvloop.install()
383443
asyncio.run(worker())

components/backends/trtllm/src/dynamo/trtllm/main.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from torch.cuda import device_count
2121
from transformers import AutoConfig
2222

23-
from dynamo.llm import ModelType, register_llm
23+
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
2424
from dynamo.runtime import DistributedRuntime, dynamo_worker
2525
from dynamo.runtime.logging import configure_dynamo_logging
26-
from dynamo.trtllm.engine import get_llm_engine
26+
from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
2727
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
2828
from dynamo.trtllm.publisher import get_publisher
2929
from dynamo.trtllm.request_handlers.handlers import (
@@ -49,6 +49,39 @@ async def graceful_shutdown(runtime):
4949
logging.info("DistributedRuntime shutdown complete")
5050

5151

52+
async def get_engine_runtime_config(
    engine: TensorRTLLMEngine, config: Config
) -> ModelRuntimeConfig:
    """Retrieve runtime configuration from TensorRT-LLM engine.

    Populates ``max_num_seqs`` and ``max_num_batched_tokens`` from ``config``
    and ``total_kv_blocks`` from the first entry yielded by the engine's
    stats stream.

    Args:
        engine: Engine whose ``llm.get_stats_async`` stream is queried.
        config: Worker config supplying ``max_batch_size`` and
            ``max_num_tokens``.

    Returns:
        The runtime config. This function never raises: on failure the error
        is logged and the config is returned with whatever fields were set
        before the failure.
    """
    runtime_config = ModelRuntimeConfig()

    try:
        # Config-derived values come first so that a failure of the engine
        # stats query below cannot prevent them from being recorded.
        runtime_config.max_num_seqs = config.max_batch_size
        logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")

        runtime_config.max_num_batched_tokens = config.max_num_tokens
        logging.info(
            f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}"
        )

        # Extract total_kv_blocks from engine stats (the step most likely to fail).
        stats = engine.llm.get_stats_async(timeout=5)
        stat = await anext(stats)
        runtime_config.total_kv_blocks = stat["kvCacheStats"]["maxNumBlocks"]
        logging.info(
            f"Set runtime config total_kv_blocks: {runtime_config.total_kv_blocks}"
        )
    except Exception as e:
        logging.error(f"Failed to get runtime config from TensorRT-LLM engine: {e}")
        # Fall through: return the config with whatever was populated so far.

    return runtime_config
83+
84+
5285
@dynamo_worker(static=False)
5386
async def worker(runtime: DistributedRuntime):
5487
# Set up signal handler for graceful shutdown
@@ -196,14 +229,18 @@ async def init(runtime: DistributedRuntime, config: Config):
196229
endpoint = component.endpoint(config.endpoint)
197230

198231
if is_first_worker(config):
199-
# Register the model with the endpoint if only the worker is first in the disaggregation chain.
232+
# Get runtime configuration from the engine
233+
runtime_config = await get_engine_runtime_config(engine, config)
234+
235+
# Register the model with runtime config
200236
await register_llm(
201237
modelType,
202238
endpoint,
203239
config.model_path,
204240
config.served_model_name,
205241
kv_cache_block_size=config.kv_block_size,
206242
migration_limit=config.migration_limit,
243+
runtime_config=runtime_config, # Add runtime config here
207244
)
208245
# publisher will be set later if publishing is enabled.
209246
handler_config = RequestHandlerConfig(

components/backends/vllm/src/dynamo/vllm/main.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from vllm.v1.engine.async_llm import AsyncLLM
1313

1414
from dynamo.llm import (
15+
ModelRuntimeConfig,
1516
ModelType,
1617
ZmqKvEventPublisher,
1718
ZmqKvEventPublisherConfig,
@@ -213,13 +214,25 @@ async def init(runtime: DistributedRuntime, config: Config):
213214
handler.kv_publisher = kv_publisher
214215

215216
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
217+
runtime_config = ModelRuntimeConfig()
218+
219+
# read runtime configuration values from the engine's vllm_config
220+
logging.info(
221+
"Getting engine runtime configuration metadata from vLLM engine..."
222+
)
223+
runtime_values = get_engine_cache_info(engine_client)
224+
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
225+
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
226+
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
227+
216228
await register_llm(
217229
ModelType.Backend,
218230
generate_endpoint,
219231
config.model,
220232
config.served_model_name,
221233
kv_cache_block_size=config.engine_args.block_size,
222234
migration_limit=config.migration_limit,
235+
runtime_config=runtime_config,
223236
)
224237

225238
try:
@@ -237,6 +250,32 @@ async def init(runtime: DistributedRuntime, config: Config):
237250
handler.cleanup()
238251

239252

253+
def get_engine_cache_info(engine: AsyncLLM):
    """Collect cache and scheduler limits from the engine's ``vllm_config``.

    Returns a dict with the keys ``num_gpu_blocks``, ``max_num_seqs`` and
    ``max_num_batched_tokens``. Any error while reading the config is logged
    and re-raised.
    """
    try:
        vllm_config = engine.vllm_config

        # Values are read straight from vllm_config rather than via collective_rpc.
        cache_values = {"num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks}
        scheduler_values = {
            "max_num_seqs": vllm_config.scheduler_config.max_num_seqs,
            "max_num_batched_tokens": vllm_config.scheduler_config.max_num_batched_tokens,
        }

        logging.info(f"Cache config values: {cache_values}")
        logging.info(f"Scheduler config values: {scheduler_values}")

        # Merge order keeps the key order: cache entry first, then scheduler entries.
        return {**cache_values, **scheduler_values}
    except Exception as e:
        logging.error(f"Failed to get configuration values from vLLM config: {e}")
        raise
277+
278+
240279
def main():
241280
uvloop.run(worker())
242281

components/router/src/main.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
// 2. Update the backend component to produce a config in a standard location.
2121
// 3. Update the KvRouter to read the config from the backend component.
2222

23+
use std::collections::HashMap;
2324
use std::sync::Arc;
2425

2526
use clap::Parser;
@@ -29,7 +30,7 @@ use dynamo_llm::kv_router::{
2930
scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
3031
KvRouter, WorkerSelector,
3132
};
32-
use dynamo_runtime::component::Instance;
33+
use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig;
3334
use dynamo_runtime::{
3435
logging, pipeline::network::Ingress, DistributedRuntime, Result, Runtime, Worker,
3536
};
@@ -86,7 +87,7 @@ pub struct CustomWorkerSelector(DefaultWorkerSelector);
8687
impl WorkerSelector for CustomWorkerSelector {
8788
fn select_worker(
8889
&self,
89-
workers: &[Instance],
90+
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
9091
request: &SchedulingRequest,
9192
block_size: u32,
9293
) -> Result<WorkerSelectionResult, KvSchedulerError> {

lib/bindings/python/rust/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ use dynamo_runtime::{
2525
use dynamo_llm::{self as llm_rs};
2626
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};
2727

28+
use crate::llm::local_model::ModelRuntimeConfig;
29+
2830
#[pyclass(eq, eq_int)]
2931
#[derive(Clone, Debug, PartialEq)]
3032
pub enum RouterMode {
@@ -82,6 +84,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
8284
m.add_class::<llm::entrypoint::KvRouterConfig>()?;
8385
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
8486
m.add_class::<llm::model_card::ModelDeploymentCard>()?;
87+
m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
8588
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
8689
m.add_class::<llm::backend::Backend>()?;
8790
m.add_class::<llm::kv::OverlapScores>()?;
@@ -131,7 +134,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
131134
}
132135

133136
#[pyfunction]
134-
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, user_data=None))]
137+
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None))]
135138
#[allow(clippy::too_many_arguments)]
136139
fn register_llm<'p>(
137140
py: Python<'p>,
@@ -143,6 +146,7 @@ fn register_llm<'p>(
143146
kv_cache_block_size: Option<u32>,
144147
router_mode: Option<RouterMode>,
145148
migration_limit: u32,
149+
runtime_config: Option<ModelRuntimeConfig>,
146150
user_data: Option<&Bound<'p, PyDict>>,
147151
) -> PyResult<Bound<'p, PyAny>> {
148152
let model_type_obj = match model_type {
@@ -173,6 +177,7 @@ fn register_llm<'p>(
173177
.kv_cache_block_size(kv_cache_block_size)
174178
.router_config(Some(router_config))
175179
.migration_limit(Some(migration_limit))
180+
.runtime_config(runtime_config.unwrap_or_default().inner)
176181
.user_data(user_data_json);
177182
// Download from HF, load the ModelDeploymentCard
178183
let mut local_model = builder.build().await.map_err(to_pyerr)?;

lib/bindings/python/rust/llm.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub mod block_manager;
3131
pub mod disagg_router;
3232
pub mod entrypoint;
3333
pub mod kv;
34+
pub mod local_model;
3435
pub mod model_card;
3536
pub mod nats;
3637
pub mod preprocessor;

lib/bindings/python/rust/llm/entrypoint.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ pub fn make_engine<'p>(
164164
.kv_cache_block_size(args.kv_cache_block_size)
165165
.router_config(args.router_config.clone().map(|rc| rc.into()))
166166
.http_port(args.http_port)
167-
.is_mocker(matches!(args.engine_type, EngineType::Mocker));
167+
.is_mocker(matches!(args.engine_type, EngineType::Mocker))
168+
.extra_engine_args(args.extra_engine_args.clone());
168169
pyo3_async_runtimes::tokio::future_into_py(py, async move {
169170
let local_model = builder.build().await.map_err(to_pyerr)?;
170171
let inner = select_engine(distributed_runtime, args, local_model)

0 commit comments

Comments
 (0)