Commit 536ad29

feat: add logits processor support for trtllm backend
1 parent 6cf96e0 commit 536ad29

File tree

8 files changed: +192 -2 lines changed

components/backends/trtllm/src/dynamo/trtllm/logits_processing/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from .adapter import TrtllmDynamoLogitsAdapter, create_trtllm_adapters
+
+__all__ = ["TrtllmDynamoLogitsAdapter", "create_trtllm_adapters"]

components/backends/trtllm/src/dynamo/trtllm/logits_processing/adapter.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import List, Optional
+
+import torch
+from tensorrt_llm.sampling_params import LogitsProcessor
+
+from dynamo.logits_processing import BaseLogitsProcessor
+
+logger = logging.getLogger(__name__)
+
+
+class TrtllmDynamoLogitsAdapter(LogitsProcessor):
+    """
+    Adapter that wraps Dynamo BaseLogitsProcessor instances to work with TensorRT-LLM's logits processor interface.
+
+    Inherits from tensorrt_llm.LogitsProcessor and implements the required interface:
+    __call__(self, req_ids: int, logits: torch.Tensor, ids: List[List[int]], stream_ptr, client_id: Optional[int])
+
+    This adapter maintains per-request state and converts between the two interfaces.
+    """
+
+    def __init__(self, processor: BaseLogitsProcessor):
+        super().__init__()
+        self.processor = processor
+
+    def __call__(
+        self,
+        req_ids: int,
+        logits: torch.Tensor,
+        ids: List[List[int]],
+        stream_ptr,
+        client_id: Optional[int] = None,
+    ):
+        """
+        TensorRT-LLM logits processor interface.
+
+        Args:
+            req_ids: Request identifier
+            logits: Logits tensor for the current step
+            ids: List of token sequences (batch of sequences)
+            stream_ptr: CUDA stream pointer
+            client_id: Optional client identifier
+
+        Returns:
+            None; the logits tensor is modified in place, as TRT-LLM expects
+        """
+        print(f"Shapes: logits {logits.shape}, ids {ids}")
+        try:
+            for ids_req, logits_req in zip(ids, logits):
+                if logits_req.shape[0] != 1:
+                    raise ValueError(
+                        "Logits processing with beam width > 1 is not supported"
+                    )
+                # Flatten away the leading (beam) dimension of logits_req
+                modified_logits = self.processor(ids_req, logits_req.reshape(-1))
+
+                # TRT-LLM expects in-place modification
+                logits.copy_(modified_logits)
+
+        except Exception as e:
+            logger.error(f"Error in logits processor for request {req_ids}: {e}")
+            # Don't modify logits on error
+
+        # TRT-LLM expects void return (in-place modification)
+
+
+def create_trtllm_adapters(
+    processors: List[BaseLogitsProcessor],
+) -> List[TrtllmDynamoLogitsAdapter]:
+    """
+    Create TensorRT-LLM compatible adapters from Dynamo logits processors.
+
+    Args:
+        processors: List of Dynamo BaseLogitsProcessor instances
+
+    Returns:
+        List of TensorRT-LLM compatible logits processor adapters
+    """
+    adapters = []
+    for processor in processors:
+        adapter = TrtllmDynamoLogitsAdapter(processor)
+        adapters.append(adapter)
+    return adapters
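
For orientation, a minimal usage sketch (not part of this commit) of how a Dynamo processor is wrapped and attached to TRT-LLM sampling parameters; the TemperatureProcessor choice and the bare SamplingParams() construction are illustrative assumptions, while the logits_processor assignment mirrors the handler change below:

# Illustrative sketch only: wire a Dynamo processor into TRT-LLM via the adapter.
from tensorrt_llm import SamplingParams

from dynamo.logits_processing.examples import TemperatureProcessor
from dynamo.trtllm.logits_processing import create_trtllm_adapters

# Any object satisfying the BaseLogitsProcessor protocol can be adapted.
processors = [TemperatureProcessor(temperature=0.7)]
adapters = create_trtllm_adapters(processors)

# Attach the adapters to the per-request sampling parameters,
# mirroring what handler_base.py does in this commit.
sampling_params = SamplingParams()
sampling_params.logits_processor = adapters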

components/backends/trtllm/src/dynamo/trtllm/main.py

Lines changed: 1 addition & 1 deletion

@@ -162,7 +162,7 @@ async def init(runtime: DistributedRuntime, config: Config):
         "pipeline_parallel_size": config.pipeline_parallel_size,
         "moe_expert_parallel_size": config.expert_parallel_size,
         "backend": "pytorch",
-        "skip_tokenizer_init": True,
+        "skip_tokenizer_init": False,
         "build_config": build_config,
         "kv_cache_config": kv_cache_config,
         "gpus_per_node": gpus_per_node,

components/backends/trtllm/src/dynamo/trtllm/request_handlers/handler_base.py

Lines changed: 7 additions & 0 deletions

@@ -21,8 +21,10 @@
 from tensorrt_llm import SamplingParams
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams

+from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
 from dynamo.runtime.logging import configure_dynamo_logging
 from dynamo.trtllm.engine import TensorRTLLMEngine
+from dynamo.trtllm.logits_processing import create_trtllm_adapters
 from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
 from dynamo.trtllm.publisher import Publisher
 from dynamo.trtllm.utils.disagg_utils import (
@@ -168,6 +170,11 @@ async def generate_locally(self, request: dict):
         request_id = request.get("id") or request.get("request_id", "unknown-id")
         model_name = request.get("model", "unknown_model")

+        # TODO: Just for testing. Hardcoding the hello world processor.
+        processors = [HelloWorldLogitsProcessor(self.engine.llm.tokenizer)]
+        adapters = create_trtllm_adapters(processors)
+        sampling_params.logits_processor = adapters
+
         # NEW: Updated engine call to include multimodal data
         async for res in self.engine.llm.generate_async(
             inputs=processed_input,  # Use the correctly extracted inputs
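
The TODO above hardcodes HelloWorldLogitsProcessor for every request. A possible follow-up, sketched here purely as an assumption (the use_hello_world request field is made up for illustration), would attach processors only when a request opts in:

# Hypothetical sketch, not in this commit: attach the example processor
# only when the request asks for it via a made-up "use_hello_world" flag.
processors = []
if request.get("use_hello_world", False):
    processors.append(HelloWorldLogitsProcessor(self.engine.llm.tokenizer))

if processors:
    sampling_params.logits_processor = create_trtllm_adapters(processors)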

lib/bindings/python/src/dynamo/logits_processing/base.py

Lines changed: 2 additions & 1 deletion

@@ -8,11 +8,12 @@
 logits processors must implement.
 """

-from typing import Protocol, Sequence
+from typing import Protocol, Sequence, runtime_checkable

 import torch


+@runtime_checkable
 class BaseLogitsProcessor(Protocol):
     """
     Protocol for logits processors in Dynamo.
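
Because the protocol is now runtime_checkable, callers can validate processor implementations structurally with isinstance; a small sketch (the NoOpProcessor class is illustrative only):

# Sketch: runtime_checkable allows structural isinstance checks against the protocol.
from typing import Sequence

import torch

from dynamo.logits_processing import BaseLogitsProcessor


class NoOpProcessor:
    def __call__(self, input_ids: Sequence[int], logits: torch.Tensor) -> torch.Tensor:
        return logits  # leave logits untouched


# Passes because NoOpProcessor exposes a matching __call__; note that
# runtime_checkable only verifies member presence, not signatures.
assert isinstance(NoOpProcessor(), BaseLogitsProcessor)
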
lib/bindings/python/src/dynamo/logits_processing/examples/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from .hello_world import HelloWorldLogitsProcessor
+from .temperature import TemperatureProcessor
+
+__all__ = ["TemperatureProcessor", "HelloWorldLogitsProcessor"]

lib/bindings/python/src/dynamo/logits_processing/examples/hello_world.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Sequence
+
+import torch
+from transformers import PreTrainedTokenizerBase
+
+from dynamo.logits_processing import BaseLogitsProcessor
+
+RESPONSE = "Hello world!"
+
+
+class HelloWorldLogitsProcessor(BaseLogitsProcessor):
+    """
+    Sample logits processor that always outputs a hardcoded
+    response (`RESPONSE`), no matter the input.
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        self.tokenizer = tokenizer
+        self.token_ids = tokenizer.encode(RESPONSE, add_special_tokens=False)
+        self.eos_id = tokenizer.eos_token_id
+        self.state = 0
+
+    def __call__(self, input_ids: Sequence[int], scores: torch.Tensor) -> torch.Tensor:
+        print("Calling logits processor")
+        mask = torch.full_like(scores, float("-inf"))
+
+        if self.state < len(self.token_ids):
+            token_idx = self.token_ids[self.state]
+        else:
+            token_idx = self.eos_id
+        # Allow only a single token to be output
+        mask[token_idx] = 0.0
+
+        # The `scores` tensor *must* also be modified in-place
+        scores.add_(mask)
+        self.state += 1
+
+        return scores
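
A quick way to sanity-check the example processor in isolation, assuming a Hugging Face tokenizer is available locally ("gpt2" is only a placeholder model name):

# Sketch: drive HelloWorldLogitsProcessor step by step and decode what it forces.
import torch
from transformers import AutoTokenizer

from dynamo.logits_processing.examples import HelloWorldLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
proc = HelloWorldLogitsProcessor(tokenizer)

generated = []
for _ in range(len(proc.token_ids)):
    scores = torch.zeros(len(tokenizer))  # fake uniform logits
    scores = proc(generated, scores)
    generated.append(int(scores.argmax()))  # only the allowed token survives the mask

print(tokenizer.decode(generated))  # expected: "Hello world!"
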
lib/bindings/python/src/dynamo/logits_processing/examples/temperature.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Sequence
+
+import torch
+
+from dynamo.logits_processing import BaseLogitsProcessor
+
+
+class TemperatureProcessor(BaseLogitsProcessor):
+    """
+    Example logits processor that applies temperature scaling.
+
+    This is a simple demonstration of how to implement a logits processor
+    that can be used with any Dynamo backend.
+    """
+
+    def __init__(self, temperature: float = 1.0):
+        """
+        Args:
+            temperature: Scaling factor. Higher values make the distribution more uniform,
+                lower values make it more peaked. Must be positive.
+        """
+        if temperature <= 0:
+            raise ValueError("Temperature must be positive")
+        self.temperature = temperature
+
+    def __call__(self, input_ids: Sequence[int], logits: torch.Tensor) -> torch.Tensor:
+        """
+        Apply temperature scaling to logits.
+
+        Args:
+            input_ids: Token IDs generated so far (unused in this simple example)
+            logits: Raw logits tensor from the model
+
+        Returns:
+            Temperature-scaled logits tensor
+        """
+        if self.temperature == 1.0:
+            return logits
+        return logits / self.temperature
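
For illustration, a minimal check of how the scaling changes the softmax distribution (the logits values are chosen arbitrarily):

# Sketch: compare softmax distributions before and after temperature scaling.
import torch

from dynamo.logits_processing.examples import TemperatureProcessor

logits = torch.tensor([2.0, 1.0, 0.0])
cold = TemperatureProcessor(temperature=0.5)(input_ids=[], logits=logits)
hot = TemperatureProcessor(temperature=2.0)(input_ids=[], logits=logits)

print(torch.softmax(logits, dim=-1))  # baseline distribution
print(torch.softmax(cold, dim=-1))    # more peaked (lower temperature)
print(torch.softmax(hot, dim=-1))     # more uniform (higher temperature)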
