1 change: 1 addition & 0 deletions components/backends/trtllm/README.md
@@ -185,6 +185,7 @@ For comprehensive instructions on multinode serving, see the [multinode-examples

### Speculative Decoding
- **[Llama 4 Maverick Instruct + Eagle Speculative Decoding](./llama4_plus_eagle.md)**
- **[Async Speculative Decoding](./async_spec_dec.md)**

### Kubernetes Deployment

54 changes: 54 additions & 0 deletions components/backends/trtllm/async_spec_dec.md
@@ -0,0 +1,54 @@
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Async Speculative Decoding

This guide demonstrates how to run Draft-Target Model (DTM) speculative decoding asynchronously in Dynamo, where the draft model and target model run as separate Dynamo workers with the TRT-LLM backend.
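
The request flow in this setup is roughly as follows (a summary based on the launch script, engine configs, and worker code below): a client sends a request to the `dynamo.frontend` HTTP server, which routes it to the verifier worker (the only worker that registers the model). The verifier runs the target model and uses a `DynamoAPIDrafter` whose draft client points at the drafter endpoint `dyn://dynamo.tensorrt_llm.generate_draft`, so draft tokens are produced by one or more separate drafter workers over that endpoint. The verifier then accepts or rejects the drafted tokens against the target model and streams the verified output back through the frontend.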

## Setup

Follow the [Quickstart setup](./README.md#quick-start) instructions. Then, inside the container, run the following example:

```
cd $DYNAMO_HOME/components/backends/trtllm
./launch/spec_dec.sh
```
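
Once the frontend, verifier, and drafter workers are up, you can send a test request. This is a minimal sketch assuming the frontend exposes its usual OpenAI-compatible `/v1/chat/completions` route on port 8000 (the port set in `launch/spec_dec.sh`) and that `SERVED_MODEL_NAME` is left at its default:

```
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Explain speculative decoding in one sentence."}],
    "stream": false,
    "max_tokens": 64
  }'
```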

To scale up the number of drafters:

```
cd $DYNAMO_HOME/components/backends/trtllm
export NUM_DRAFTERS=2
export DRAFTER_CUDA_VISIBLE_DEVICES="1,2"
./launch/spec_dec.sh
```
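
Each drafter launched by `spec_dec.sh` is pinned to one GPU ID taken, in order, from `DRAFTER_CUDA_VISIBLE_DEVICES`, so the list must contain at least `NUM_DRAFTERS` entries; otherwise the script exits with an error before starting any workers.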

## Parallel Speculative Decoding

To enable parallel speculative decoding, add the `--parallel-spec-dec` flag to the verifier worker command (for example, in `launch/spec_dec.sh`):

```
# run verifier worker with speculative decoding
CUDA_VISIBLE_DEVICES=$VERIFIER_CUDA_VISIBLE_DEVICES \
python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$VERIFIER_ENGINE_ARGS" \
--spec-dec-mode "verifier" \
--parallel-spec-dec &
VERIFIER_PID=$!
```
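
For reference, in the accompanying `main.py` and handler changes, this flag causes the verifier worker to build a `parallel_spec_dec_config` dict (with `pre_verify` enabled and an empty `old_draft_tokens` list) and pass it to the engine's generate call as `parallel_spec_dec_params`; the intent, as the name suggests, appears to be letting drafting and verification overlap rather than strictly alternate.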

Contributor:
Can you describe the request flow in this setup as well?

Contributor:
May be a mermaid diagram?

32 changes: 32 additions & 0 deletions components/backends/trtllm/engine_configs/drafter.yaml
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true

kv_cache_config:
free_gpu_memory_fraction: 0.95

# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
cuda_graph_config:
max_batch_size: 16
36 changes: 36 additions & 0 deletions components/backends/trtllm/engine_configs/verifier.yaml
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 1024
max_batch_size: 16
trust_remote_code: true
enable_chunked_prefill: true
# External API does not support overlap scheduler
disable_overlap_scheduler: true

drafter_config:
drafter_endpoint: dyn://dynamo.tensorrt_llm.generate_draft
max_draft_len: 10

kv_cache_config:
enable_block_reuse: false

next_client: dyn://dynamo.tensorrt_llm.generate_draft

cuda_graph_config:
max_batch_size: 16
76 changes: 76 additions & 0 deletions components/backends/trtllm/launch/spec_dec.sh
@@ -0,0 +1,76 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Environment variables with defaults
# Verifier variables
export MODEL_PATH=${MODEL_PATH:-"meta-llama/Meta-Llama-3.1-8B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"meta-llama/Meta-Llama-3.1-8B-Instruct"}
export VERIFIER_ENGINE_ARGS=${VERIFIER_ENGINE_ARGS:-"engine_configs/verifier.yaml"}
export VERIFIER_CUDA_VISIBLE_DEVICES=${VERIFIER_CUDA_VISIBLE_DEVICES:-"0"}

# Drafter variables
export NUM_DRAFTERS=${NUM_DRAFTERS:-1}
export DRAFTER_MODEL_PATH=${DRAFTER_MODEL_PATH:-"meta-llama/Meta-Llama-3.2-1B-Instruct"}
export DRAFTER_MODEL_NAME=${DRAFTER_MODEL_NAME:-"meta-llama/Meta-Llama-3.2-1B-Instruct"}
export DRAFTER_ENGINE_ARGS=${DRAFTER_ENGINE_ARGS:-"engine_configs/drafter.yaml"}
export DRAFTER_CUDA_VISIBLE_DEVICES=${DRAFTER_CUDA_VISIBLE_DEVICES:-"1"}

declare -a DRAFTER_PIDS=()

# Check enough GPUs for drafters
IFS=',' read -ra CUDA_DEVICES <<< "$DRAFTER_CUDA_VISIBLE_DEVICES"
if [ ${#CUDA_DEVICES[@]} -lt $NUM_DRAFTERS ]; then
echo "Error: Not enough CUDA devices specified for drafters. Need $NUM_DRAFTERS devices, but only ${#CUDA_DEVICES[@]} provided."
exit 1
fi

# Check num drafters >= 1
if [[ $NUM_DRAFTERS -lt 1 ]]; then
echo "Error: NUM_DRAFTERS must be >= 1, got: $NUM_DRAFTERS"
exit 1
fi

# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $VERIFIER_PID "${DRAFTER_PIDS[@]}" 2>/dev/null || true
wait $DYNAMO_PID $VERIFIER_PID "${DRAFTER_PIDS[@]}" 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM

# run clear_namespace
python3 utils/clear_namespace.py --namespace dynamo

# run frontend
python3 -m dynamo.frontend --http-port 8000 &
DYNAMO_PID=$!

# run verifier worker with speculative decoding
CUDA_VISIBLE_DEVICES=$VERIFIER_CUDA_VISIBLE_DEVICES \
python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$VERIFIER_ENGINE_ARGS" \
--spec-dec-mode "verifier" &
VERIFIER_PID=$!

# run drafter workers without speculative decoding
start_drafter() {
local cuda_device=$1
CUDA_VISIBLE_DEVICES=$cuda_device \
python3 -m dynamo.trtllm \
--endpoint "dyn://dynamo.tensorrt_llm.generate_draft" \
--model-path "$DRAFTER_MODEL_PATH" \
--served-model-name "$DRAFTER_MODEL_NAME" \
--extra-engine-args "$DRAFTER_ENGINE_ARGS" \
--spec-dec-mode "drafter" &
DRAFTER_PIDS+=($!)
}

for ((i=0; i<$NUM_DRAFTERS; i++)); do
start_drafter ${CUDA_DEVICES[$i]}
done

wait
45 changes: 41 additions & 4 deletions components/backends/trtllm/src/dynamo/trtllm/main.py
@@ -15,6 +15,7 @@
KvCacheConfig,
SchedulerConfig,
)
from tensorrt_llm.llmapi.llm_args import ExternalAPIConfig, UserProvidedDecodingConfig
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from torch.cuda import device_count
@@ -30,10 +31,13 @@
RequestHandlerConfig,
RequestHandlerFactory,
)
from dynamo.trtllm.utils.api_drafter import DynamoAPIDrafter
from dynamo.trtllm.utils.trtllm_utils import (
Config,
cmd_line_args,
is_drafter,
is_first_worker,
is_verifier,
parse_endpoint,
)

@@ -122,6 +126,7 @@ async def init(runtime: DistributedRuntime, config: Config):
dynamic_batch_config=dynamic_batch_config,
)
modality = getattr(config, "modality", None) or "text"
parallel_spec_dec = getattr(config, "parallel_spec_dec", False)
arg_map = {
"model": model_path,
"scheduler_config": scheduler_config,
@@ -165,6 +170,39 @@ async def init(runtime: DistributedRuntime, config: Config):
)
sys.exit(1)

if is_verifier(config):
# Set speculative_config to use DynamoAPIDrafter in the verifier worker.
try:
drafter_config = arg_map["drafter_config"]
drafter_endpoint = drafter_config["drafter_endpoint"]
if "max_draft_len" in drafter_config:
max_draft_len = drafter_config["max_draft_len"]
else:
max_draft_len = 10
except KeyError:
raise ValueError(
"Verifier worker requires drafter_config to be specified in the engine config with drafter_endpoint."
)
del arg_map["drafter_config"]

drafter_config = ExternalAPIConfig(
endpoint=drafter_endpoint,
max_draft_len=max_draft_len,
)
drafter = DynamoAPIDrafter(spec_config=drafter_config, draft_client=next_client)
spec_config = UserProvidedDecodingConfig(
drafter=drafter,
max_draft_len=max_draft_len,
)
arg_map["speculative_config"] = spec_config

parallel_spec_dec_config = None
if parallel_spec_dec:
parallel_spec_dec_config = {
"pre_verify": True,
"old_draft_tokens": [],
}

logging.info(f"TensorRT-LLM engine args: {arg_map}")
engine_args = arg_map

@@ -175,7 +213,6 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params.stop = None
modelType = ModelType.Backend
multimodal_processor = None

if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False
modelType = ModelType.Chat
@@ -194,9 +231,9 @@ async def init(runtime: DistributedRuntime, config: Config):

async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)

if is_first_worker(config):
if is_first_worker(config) and not is_drafter(config):
# Register the model with the endpoint only if the worker is first in the disaggregation chain.
# Drafter workers should not register.
await register_llm(
modelType,
endpoint,
@@ -215,8 +252,8 @@ async def init(runtime: DistributedRuntime, config: Config):
disaggregation_strategy=config.disaggregation_strategy,
next_client=next_client,
multimodal_processor=multimodal_processor,
parallel_spec_dec_config=parallel_spec_dec_config,
)

if config.publish_events_and_metrics and is_first_worker(config):
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
@@ -60,6 +60,9 @@ class RequestHandlerConfig:
multimodal_processor: Optional[
MultimodalRequestProcessor
] = None # for multimodal support
parallel_spec_dec_config: Optional[
dict
] = None # for parallel speculative decoding support


class HandlerBase:
@@ -76,6 +79,7 @@ def __init__(self, config: RequestHandlerConfig):
self.disaggregation_strategy = config.disaggregation_strategy
self.next_client = config.next_client
self.multimodal_processor = config.multimodal_processor
self.parallel_spec_dec_config = config.parallel_spec_dec_config
self.first_generation = True

def check_error(self, result: dict):
@@ -174,6 +178,7 @@ async def generate_locally(self, request: dict):
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=streaming,
parallel_spec_dec_params=self.parallel_spec_dec_config,
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.