Skip to content

Commit 9e929ab

Browse files
committed
more fixes
1 parent 945fb9b commit 9e929ab

File tree

5 files changed

+162
-3
lines changed

5 files changed

+162
-3
lines changed

components/backends/trtllm/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ For comprehensive instructions on multinode serving, see the [multinode-examples
185185

186186
### Speculative Decoding
187187
- **[Llama 4 Maverick Instruct + Eagle Speculative Decoding](./llama4_plus_eagle.md)**
188+
- **[Async Speculative Decoding](./async_spec_dec.md)**
188189

189190
### Kubernetes Deployment
190191

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
<!--
2+
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
SPDX-License-Identifier: Apache-2.0
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
-->
17+
18+
# Async Speculative Decoding
19+
20+
This guide demonstrates how to run Draft-Target Model (DTM) speculative decoding asynchronously in Dynamo, where the draft model and target model run as separate Dynamo workers with the TRT-LLM backend.
21+
22+
## Setup
23+
24+
Follow the [Quickstart setup](./README.md#quick-start) instructions. Then, inside the container, run the following example:
25+
26+
```
27+
cd $DYNAMO_HOME/components/backends/trtllm
28+
./launch/spec_dec.sh
29+
```
30+
31+
To scale up the number of drafters:
32+
33+
```
34+
cd $DYNAMO_HOME/components/backends/trtllm
35+
export NUM_DRAFTERS=2
36+
export DRAFTER_CUDA_VISIBLE_DEVICES:-"1,2"
37+
./launch/spec_dec.sh
38+
```

components/backends/trtllm/launch/spec_dec.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ export VERIFIER_CUDA_VISIBLE_DEVICES=${VERIFIER_CUDA_VISIBLE_DEVICES:-"0"}
1111

1212
# Drafter variables
1313
export NUM_DRAFTERS=${NUM_DRAFTERS:-1}
14-
export DRAFTER_MODEL_PATH=${MODEL_PATH:-"meta-llama/Meta-Llama-3.2-1B-Instruct"}
14+
export DRAFTER_MODEL_PATH=${DRAFTER_MODEL_PATH:-"meta-llama/Meta-Llama-3.2-1B-Instruct"}
1515
export DRAFTER_MODEL_NAME=${DRAFTER_MODEL_NAME:-"meta-llama/Meta-Llama-3.2-1B-Instruct"}
1616
export DRAFTER_ENGINE_ARGS=${DRAFTER_ENGINE_ARGS:-"engine_configs/drafter.yaml"}
1717
export DRAFTER_CUDA_VISIBLE_DEVICES=${DRAFTER_CUDA_VISIBLE_DEVICES:-"1"}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
import logging
18+
import os
19+
from typing import Dict, List
20+
21+
from tensorrt_llm._torch.speculative.external_api import APIDrafter
22+
23+
from dynamo.runtime import DistributedRuntime
24+
from dynamo.runtime.logging import configure_dynamo_logging
25+
26+
configure_dynamo_logging()
27+
# TODO: remove this
28+
logging.getLogger().setLevel(logging.WARNING)
29+
30+
31+
class DynamoAPIDrafter(APIDrafter):
32+
"""
33+
Custom Dynamo drafter to support internal Dynamo endpoints instead of only HTTP endpoints.
34+
"""
35+
36+
def __init__(self, spec_config, runtime: DistributedRuntime):
37+
super().__init__(spec_config)
38+
self.client = None
39+
self.max_draft_len = spec_config.max_draft_len
40+
# TODO: allow custom etcd connection info to be set in the spec_config
41+
self.connection_info: Dict[str, str] = {}
42+
43+
async def _create_client(self):
44+
try:
45+
# parse endpoint
46+
endpoint_path = self.endpoint.replace("dyn://", "")
47+
parts = endpoint_path.split(".")
48+
if len(parts) != 3:
49+
raise ValueError(
50+
f"Invalid Dynamo endpoint format. Received: {self.endpoint}, but expected: dyn://namespace.component.endpoint"
51+
)
52+
namespace, component, endpoint = parts
53+
54+
# create minimal runtime for client access only
55+
etcd_endpoints = self.connection_info.get(
56+
"etcd_endpoints", "localhost:2379"
57+
)
58+
os.environ.setdefault("ETCD_ENDPOINTS", etcd_endpoints)
59+
loop = asyncio.get_event_loop()
60+
self.runtime = DistributedRuntime(loop, False)
61+
62+
self.client = (
63+
await self.runtime.namespace(namespace)
64+
.component(component)
65+
.endpoint(endpoint)
66+
.client()
67+
)
68+
except Exception as e:
69+
logging.error(
70+
f"Failed to create client for Dynamo endpoint: {self.endpoint} with error: {e}"
71+
)
72+
raise e
73+
74+
async def get_draft_tokens(
75+
self,
76+
prefix: list[int],
77+
request_id: int,
78+
end_id: int,
79+
max_sequence_length: int,
80+
) -> List[int]:
81+
print(f"VERIFIER: {prefix}\n")
82+
if self.endpoint.startswith("dyn://"):
83+
request_data = {
84+
"token_ids": prefix,
85+
"sampling_options": {},
86+
"stop_conditions": {
87+
"max_tokens": self.max_draft_len,
88+
},
89+
}
90+
91+
if self.client is None:
92+
await self._create_client()
93+
94+
draft_tokens = List[int] = []
95+
try:
96+
if self.client is None:
97+
logging.error(
98+
f"Failed to create client for Dynamo endpoint: {self.endpoint}"
99+
)
100+
return []
101+
response = await self.client.round_robin(request_data)
102+
103+
async for chunk in response:
104+
chunk_data = chunk.data()
105+
if chunk_data.get("finish_reason"):
106+
break
107+
draft_tokens.extend(chunk_data.get("token_ids", []))
108+
if len(draft_tokens) >= self.max_draft_len:
109+
break
110+
print(f"DRAFTER: {draft_tokens}\n")
111+
return draft_tokens[: self.max_draft_len]
112+
except Exception as e:
113+
logging.error(
114+
f"Failed to get draft tokens for Dynamo endpoint: {self.endpoint} with error: {e}"
115+
)
116+
raise e
117+
else:
118+
raise ValueError(
119+
f"Invalid Dynamo endpoint format. Received: {self.endpoint}, but expected: dyn://namespace.component.endpoint"
120+
)

components/backends/trtllm/src/dynamo/trtllm/utils/trtllm_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ def is_drafter(config):
9999
"""
100100
Check if the current worker is a drafter worker.
101101
"""
102-
return config.component == "drafter"
102+
return config.spec_dec_mode == "drafter"
103103

104104

105105
def is_verifier(config):
106106
"""
107107
Check if the current worker is a verifier worker.
108108
"""
109-
return config.component == "verifier"
109+
return config.spec_dec_mode == "verifier"
110110

111111

112112
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:

0 commit comments

Comments
 (0)