Commit 1b1e089
feat: Enable dynamo-run out=trtllm (#1223)
1 parent fc31a51 commit 1b1e089

File tree: 10 files changed, +148 −30 lines

docs/guides/dynamo_run.md

Lines changed: 22 additions & 5 deletions
@@ -12,7 +12,7 @@
 * [llama.cpp](#llamacpp)
 * [Sglang](#sglang)
 * [Vllm](#vllm)
-* [TensorRT-LLM](#tensorrt-llm-engine)
+* [TensorRT-LLM](#trtllm)
 * [Echo Engines](#echo-engines)
 * [Writing your own engine in Python](#writing-your-own-engine-in-python)
 * [Batch mode](#batch-mode)
@@ -437,10 +437,13 @@ Startup can be slow so you may want to `export DYN_LOG=debug` to see progress.
 
 Shutdown: `ray stop`
 
-#### TensorRT-LLM engine
+#### trtllm
 
-To run a TRT-LLM model with dynamo-run we have included a python based [async engine](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/engines/agg_engine.py).
-To configure the TensorRT-LLM async engine please see [llm_api_config.yaml](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/configs/llm_api_config.yaml). The file defines the options that need to be passed to the LLM engine. Follow the steps below to serve trtllm on dynamo run.
+Uses [TensorRT-LLM's LLM API](https://nvidia.github.io/TensorRT-LLM/llm-api/), a high-level Python API.
+
+You can use `--extra-engine-args` to pass extra arguments to the LLM API engine.
+
+The trtllm engine requires [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`) to be running.
 
 ##### Step 1: Build the environment
@@ -454,7 +457,7 @@ See instructions [here](https://github.com/ai-dynamo/dynamo/blob/main/examples/t
 
 Execute the following to load the TensorRT-LLM model specified in the configuration.
 ```
-dynamo run out=pystr:/workspace/examples/tensorrt_llm/engines/trtllm_engine.py -- --engine_args /workspace/examples/tensorrt_llm/configs/llm_api_config.yaml
+dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0
 ```
 
 #### Echo Engines
@@ -529,6 +532,20 @@ Pass it like this:
 ```
 dynamo-run out=sglang ~/llms/Llama-3.2-3B-Instruct --extra-engine-args sglang_extra.json
 ```
+
+The trtllm backend also supports passing any argument the engine accepts. In this case the config must be a YAML file:
+
+```
+backend: pytorch
+kv_cache_config:
+  event_buffer_max_size: 1024
+```
+
+Pass it like this:
+```
+dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args trtllm_extra.yaml
+```
 
 ### Writing your own engine in Python
 
 Note: This section replaces "bring-your-own-engine".
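
For intuition, here is a rough, hypothetical sketch of what passing a YAML overrides file amounts to. The real merge is done by TensorRT-LLM's `update_llm_args_with_extra_options` (used in trtllm_inc.py below); the base arguments shown are illustrative, not the engine's actual defaults, and `dict.update` is a shallow merge:

```python
# Rough sketch only: what an --extra-engine-args YAML file amounts to.
# The real merge is TensorRT-LLM's update_llm_args_with_extra_options;
# base_args keys here are illustrative placeholders.
import yaml  # assumes PyYAML is available

base_args = {
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "tensor_parallel_size": 1,
}

with open("trtllm_extra.yaml") as f:  # the file shown above
    extra = yaml.safe_load(f)

base_args.update(extra)  # YAML-provided keys override the defaults
print(base_args)
```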

launch/dynamo-run/src/lib.rs

Lines changed: 34 additions & 0 deletions
@@ -223,6 +223,40 @@ pub async fn run(
            }));
            EngineConfig::Dynamic
        }
+        Output::Trtllm => {
+            if flags.base_gpu_id != 0 {
+                anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
+            }
+
+            // If `in=dyn` we want the trtllm subprocess to listen on that endpoint.
+            // If not, then the endpoint isn't exposed so we invent an internal one.
+            let endpoint = match &in_opt {
+                Input::Endpoint(path) => path.parse()?,
+                _ => INTERNAL_ENDPOINT.parse()?,
+            };
+
+            let (py_script, child) = match subprocess::start(
+                subprocess::trtllm::PY,
+                &local_model,
+                &endpoint,
+                flags.clone(),
+                None, // multi-node config. trtllm uses `mpi`, see guide
+            )
+            .await
+            {
+                Ok(x) => x,
+                Err(err) => {
+                    anyhow::bail!("Failed starting trtllm sub-process: {err}");
+                }
+            };
+            let cancel_token = cancel_token.clone();
+
+            // Sub-process cleanup
+            extra = Some(Box::pin(async move {
+                stopper(cancel_token, child, py_script).await;
+            }));
+            EngineConfig::Dynamic
+        }
 
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => {

launch/dynamo-run/src/opt.rs

Lines changed: 6 additions & 0 deletions
@@ -101,6 +101,9 @@ pub enum Output {
     /// Run inference using sglang
     SgLang,
 
+    /// Run inference using trtllm
+    Trtllm,
+
     // Start vllm in a sub-process connecting via nats
     // Sugar for `python vllm_inc.py --endpoint <thing> --model <thing>`
     Vllm,
@@ -125,6 +128,7 @@ impl TryFrom<&str> for Output {
             "llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),
 
             "sglang" => Ok(Output::SgLang),
+            "trtllm" => Ok(Output::Trtllm),
             "vllm" => Ok(Output::Vllm),
 
             "echo_full" => Ok(Output::EchoFull),
@@ -164,6 +168,7 @@ impl fmt::Display for Output {
             Output::LlamaCpp => "llamacpp",
 
             Output::SgLang => "sglang",
+            Output::Trtllm => "trtllm",
             Output::Vllm => "vllm",
 
             Output::EchoFull => "echo_full",
@@ -210,6 +215,7 @@ impl Output {
         }
 
         out.push(Output::SgLang.to_string());
+        out.push(Output::Trtllm.to_string());
         out.push(Output::Vllm.to_string());
 
         #[cfg(feature = "python")]

launch/dynamo-run/src/subprocess.rs

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ use dynamo_llm::local_model::LocalModel;
 use dynamo_runtime::protocols::Endpoint as EndpointId;
 
 pub mod sglang;
+pub mod trtllm;
 pub mod vllm;
 
 pub async fn start(

launch/dynamo-run/src/subprocess/trtllm_inc.py

Lines changed: 39 additions & 21 deletions
@@ -2,16 +2,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # TODO:
-# - Add event and metrics publishers
-# - Support default dynamo-run out=trtllm launch
 # - Support disaggregated serving
+# - Update examples to use this engine.
 #
+# `dynamo-run out=trtllm` runs this script
 # Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
 
 import argparse
 import asyncio
 import logging
 import sys
+import warnings
 from typing import Optional
 
 import uvloop
@@ -20,10 +21,13 @@
 from tensorrt_llm import SamplingParams
 from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
 from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
-from trtllm.engine import get_llm_engine
-from trtllm.publishers import Publishers
 
-from dynamo.llm import ModelType, register_llm
+from dynamo.llm import (
+    ModelType,
+    get_tensorrtllm_engine,
+    get_tensorrtllm_publisher,
+    register_llm,
+)
 from dynamo.runtime import DistributedRuntime, dynamo_worker
 
 # Only used if you run it manually from the command line
@@ -44,7 +48,7 @@ class Config:
     component: str
     endpoint: str
     model_path: str
-    model_name: Optional[str]
+    model_name: Optional[str] = None
     tensor_parallel_size: int
     kv_block_size: int
     extra_engine_args: str
@@ -65,7 +69,9 @@ def __init__(self, component, engine, default_sampling_params, publishers):
 
     async def generate(self, request):
         # Check if there is an error in the publishers error queue
-        publishers_error = self.publishers.check_error_queue()
+        publishers_error = (
+            self.publishers.check_error_queue() if self.publishers else None
+        )
         if publishers_error:
             raise publishers_error
 
@@ -90,7 +96,7 @@ async def generate(self, request):
             # TRTLLM engine needs to start generating tokens first before stats
             # can be retrieved.
             if self.first_generation and self.publishers:
-                self.publishers.start_publish_threads()
+                self.publishers.start()
                 self.first_generation = False
 
             if res.finished:
@@ -137,6 +143,7 @@ async def init(runtime: DistributedRuntime, config: Config):
         "disable_log_stats": False,
     }
     if config.extra_engine_args != "":
+        # TODO: Support extra engine args from json file as well.
        arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
     if config.publish_events_and_metrics:
         # 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
@@ -168,34 +175,33 @@ async def init(runtime: DistributedRuntime, config: Config):
     default_sampling_params._setup(tokenizer)
     default_sampling_params.stop = None
 
-    async with get_llm_engine(engine_args) as engine:
+    async with get_tensorrtllm_engine(engine_args) as engine:
         endpoint = component.endpoint(config.endpoint)
         await register_llm(
             ModelType.Backend, endpoint, config.model_path, config.model_name
         )
 
-        publishers = None
         if config.publish_events_and_metrics:
+            # Initialize and pass in the publishers to the request handler to
+            # publish events and metrics.
             kv_listener = runtime.namespace(config.namespace).component(
                 config.component
             )
-            publishers = Publishers(
+            async with get_tensorrtllm_publisher(
                 component,
                 engine,
                 kv_listener,
                 int(endpoint.lease_id()),
                 config.kv_block_size,
-            )
-
-        handler = RequestHandler(component, engine, default_sampling_params, publishers)
-
-        try:
-            # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
-            # after the lease is revoked
+            ) as publisher:
+                handler = RequestHandler(
+                    component, engine, default_sampling_params, publisher
+                )
+                await endpoint.serve_endpoint(handler.generate)
+        else:
+            # No publishers, so just pass in None to the request handler.
+            handler = RequestHandler(component, engine, default_sampling_params, None)
             await endpoint.serve_endpoint(handler.generate)
-        finally:
-            if publishers:
-                await publishers.cleanup()
 
 
 def cmd_line_args():
@@ -228,6 +234,12 @@ def cmd_line_args():
     parser.add_argument(
         "--kv-block-size", type=int, default=32, help="Size of a KV cache block."
     )
+    parser.add_argument(
+        "--context-length",
+        type=int,
+        default=None,
+        help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in a yaml file and point --extra-engine-args to it.",
+    )
     parser.add_argument(
         "--extra-engine-args",
         type=str,
@@ -241,6 +253,12 @@ def cmd_line_args():
     )
     args = parser.parse_args()
 
+    if args.context_length is not None:
+        warnings.warn(
+            "--context-length is accepted for compatibility but will be ignored by TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in a yaml file and point --extra-engine-args to it.",
+            UserWarning,
+        )
+
     config = Config()
     config.model_path = args.model_path
    if args.model_name:
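
The guard in `RequestHandler.generate` above starts the publisher threads only after the engine has produced its first output, and only when a publisher exists. A minimal self-contained sketch of that lazy-start pattern (generic stand-ins, not the real engine or publisher):

```python
# Stand-alone sketch of the lazy-start guard in RequestHandler.generate:
# publisher threads start once, on the first generated output, and only if
# a publisher was supplied. FakePublisher is a hypothetical stand-in.
class Handler:
    def __init__(self, publishers=None):
        self.publishers = publishers
        self.first_generation = True

    def on_first_output(self):
        # Mirrors: if self.first_generation and self.publishers: ...
        if self.first_generation and self.publishers:
            self.publishers.start()
        self.first_generation = False

class FakePublisher:
    def start(self):
        print("publisher threads started")

handler = Handler(FakePublisher())
handler.on_first_output()  # starts the publisher exactly once
handler.on_first_output()  # no-op on subsequent outputs
```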

lib/bindings/python/src/dynamo/llm/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -33,3 +33,13 @@
 from dynamo._core import ModelType as ModelType
 from dynamo._core import OverlapScores as OverlapScores
 from dynamo._core import register_llm as register_llm
+
+try:
+    from dynamo.llm.tensorrtllm import (  # noqa: F401
+        get_llm_engine as get_tensorrtllm_engine,
+    )
+    from dynamo.llm.tensorrtllm import (  # noqa: F401
+        get_publisher as get_tensorrtllm_publisher,
+    )
+except ImportError:
+    pass  # TensorRTLLM is not enabled by default
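
Downstream code can then import these helpers from `dynamo.llm` when the TRT-LLM extra is present. A hedged usage sketch, mirroring how trtllm_inc.py nests the two context managers (the `engine_args` contents, component, and listener values are placeholders, not a fixed API contract):

```python
# Hedged sketch: how a consumer might use the optional re-exports above.
# Assumes a TRT-LLM-enabled install; engine_args keys are placeholders.
from dynamo.llm import get_tensorrtllm_engine, get_tensorrtllm_publisher

async def serve(component, kv_listener, worker_id, kv_block_size=32):
    engine_args = {"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}  # placeholder
    async with get_tensorrtllm_engine(engine_args) as engine:
        async with get_tensorrtllm_publisher(
            component, engine, kv_listener, worker_id, kv_block_size
        ) as publisher:
            ...  # hand `engine` and `publisher` to a request handler
```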
lib/bindings/python/src/dynamo/llm/tensorrtllm/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .engine import get_llm_engine  # noqa: F401
+from .publisher import get_publisher  # noqa: F401

launch/dynamo-run/src/subprocess/trtllm/publishers.py renamed to lib/bindings/python/src/dynamo/llm/tensorrtllm/publisher.py

Lines changed: 17 additions & 4 deletions
@@ -7,6 +7,7 @@
 import threading
 import traceback
 import weakref
+from contextlib import asynccontextmanager
 from queue import Queue
 from typing import Callable, Optional, Union
 
@@ -80,7 +81,7 @@ def stop(self):
             self._current_future.cancel()
 
 
-class Publishers:
+class Publisher:
     """
     A class to retrieve stats and kv cache events from TRTLLM engine and publish them to the metrics and events publishers.
     """
@@ -102,7 +103,6 @@ def __init__(self, component, engine, kv_listener, worker_id, kv_block_size):
         self.partial_block_hashes = set()
         self.error_queue: Queue = Queue()
         self._stop_event = threading.Event()
-        self._setup()
 
     async def _create_metrics_publisher_endpoint(self):
         logging.debug("Creating metrics publisher endpoint")
@@ -111,7 +111,7 @@ async def _create_metrics_publisher_endpoint(self):
             return
         await self.metrics_publisher.create_endpoint(self.component)
 
-    def _setup(self):
+    def initialize(self):
         # Setup the metrics publisher
         self.metrics_publisher = KvMetricsPublisher()
         self._init_publish_metrics_thread()
@@ -298,7 +298,7 @@ async def _publish_kv_cache_events_task(self):
             self.kv_event_publisher.publish_removed(event_id, block_hashes)
         return True
 
-    def start_publish_threads(self):
+    def start(self):
         if (
             self.publish_kv_cache_events_thread
             and not self.publish_kv_cache_events_thread.is_alive()
@@ -342,3 +342,16 @@ async def cleanup(self):
             self.publish_kv_cache_events_thread.join(timeout=cleanup_timeout)
             if self.publish_kv_cache_events_thread.is_alive():
                 logging.warning("KV cache events thread did not stop within timeout")
+
+
+@asynccontextmanager
+async def get_publisher(component, engine, kv_listener, worker_id, kv_block_size):
+    publisher = Publisher(component, engine, kv_listener, worker_id, kv_block_size)
+    try:
+        publisher.initialize()
+        yield publisher
+    except Exception as e:
+        logging.error(f"Error in engine context: {e}")
+        raise
+    finally:
+        await publisher.cleanup()
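
The `get_publisher` wrapper guarantees `cleanup()` runs whether or not the body raises. A minimal stand-alone sketch of that initialize/yield/cleanup contract (generic names, not the real publisher classes):

```python
# Stand-alone illustration of the initialize/yield/cleanup contract used by
# get_publisher above. Generic names; not the real publisher classes.
import asyncio
from contextlib import asynccontextmanager

class Resource:
    def initialize(self):
        print("initialized")

    async def cleanup(self):
        print("cleaned up")  # runs even if the `async with` body raises

@asynccontextmanager
async def get_resource():
    resource = Resource()
    try:
        resource.initialize()
        yield resource
    finally:
        await resource.cleanup()

async def main():
    async with get_resource() as r:
        print("using", r)

asyncio.run(main())
```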

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ addopts = [
     "--ignore-glob=*model.py",
     "--ignore-glob=*_inc.py",
     "--ignore-glob=deploy/cloud/api-store/*",
+    "--ignore-glob=*/llm/tensorrtllm*",
     # FIXME: Get relative/generic blob paths to work here
 ]
 xfail_strict = true
