feat: add logits processor support for trtllm backend #2702
Merged
9 commits
e5071f4 feat: add logits processor support for trtllm backend (bhuvan002)
7478710 fix: use cuda stream provided by trtllm (bhuvan002)
553ca88 chore: refactor logits processors to be in place (bhuvan002)
9c4a10e chore: cleanup (bhuvan002)
fac26aa chore: move around imports (bhuvan002)
efd9dfa chore: Fail fast in helloworld processor if no eos token found (bhuvan002)
1691e2a chore: add e2e test (bhuvan002)
54d0d84 fix: only init tokenizer when needed for logits processors (bhuvan002)
d639382 docs: add usage instructions on logits processors (bhuvan002)
88 changes: 88 additions & 0 deletions
components/backends/trtllm/src/dynamo/trtllm/logits_processing/adapter.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| | ||
| import logging | ||
| from typing import List, Optional | ||
| | ||
| import torch | ||
| from tensorrt_llm.sampling_params import LogitsProcessor | ||
| | ||
| from dynamo.logits_processing import BaseLogitsProcessor | ||
| | ||
| logger = logging.getLogger(__name__) | ||
| | ||
| | ||
| class TrtllmDynamoLogitsAdapter(LogitsProcessor): | ||
|     """ | ||
|     Adapter that wraps Dynamo BaseLogitsProcessor instances to work with TensorRT-LLM's logits processor interface. | ||
| | ||
|     Inherits from tensorrt_llm.LogitsProcessor and implements the required interface: | ||
|     __call__(self, req_ids: int, logits: torch.Tensor, ids: List[List[int]], stream_ptr, client_id: Optional[int]) | ||
| | ||
|     This adapter maintains per-request state and converts between the interfaces. | ||
|     """ | ||
| | ||
|     def __init__(self, processor: BaseLogitsProcessor): | ||
|         super().__init__() | ||
|         self.processor = processor | ||
| | ||
|     def __call__( | ||
|         self, | ||
|         req_ids: int, | ||
|         logits: torch.Tensor, | ||
|         ids: List[List[int]], | ||
|         stream_ptr, | ||
|         client_id: Optional[int] = None, | ||
|     ): | ||
|         """ | ||
|         TensorRT-LLM logits processor interface. | ||
| | ||
|         Args: | ||
|             req_ids: Request identifier | ||
|             logits: Logits tensor for current step | ||
|             ids: List of token sequences (batch of sequences) | ||
|             stream_ptr: CUDA stream pointer | ||
|             client_id: Optional client identifier | ||
| | ||
|         Returns: | ||
|             Modified logits tensor (in-place modification expected by TRT-LLM) | ||
|         """ | ||
|         stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) | ||
|         try: | ||
|             with torch.cuda.stream(stream): | ||
|                 if logits.shape[0] != 1: | ||
|                     raise ValueError( | ||
|                         f"This logits adapter only supports per-request logits processing. " | ||
|                         f"Received logits with batch size {logits.shape[0]} expected 1" | ||
|                     ) | ||
|                 if logits.shape[1] != 1: | ||
|                     raise ValueError( | ||
|                         "Logits processing with beam width > 1 is not supported" | ||
|                     ) | ||
|                 # Call the processor which modifies the logits in-place | ||
|                 self.processor(ids[0], logits[0, 0, :]) | ||
| | ||
|         except Exception as e: | ||
|             logger.error(f"Error in logits processor for request {req_ids}: {e}") | ||
|             # Don't modify logits on error | ||
| | ||
|         # TRT-LLM expects void return (in-place modification) | ||
| | ||
| | ||
| def create_trtllm_adapters( | ||
|     processors: List[BaseLogitsProcessor], | ||
| ) -> List[TrtllmDynamoLogitsAdapter]: | ||
|     """ | ||
|     Create TensorRT-LLM compatible adapters from Dynamo logits processors. | ||
| | ||
|     Args: | ||
|         processors: List of Dynamo BaseLogitsProcessor instances | ||
| | ||
|     Returns: | ||
|         List of TensorRT-LLM compatible logits processor adapters | ||
|     """ | ||
|     adapters = [] | ||
|     for processor in processors: | ||
|         adapter = TrtllmDynamoLogitsAdapter(processor) | ||
|         adapters.append(adapter) | ||
|     return adapters | ||
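
The adapter's control flow (check for a single request and a single beam, hand the one vocabulary-sized row to the wrapped processor, log and swallow any error) can be tried without TensorRT-LLM or PyTorch. The sketch below is only an illustration under stated assumptions: nested Python lists stand in for the `[batch, beam, vocab]` logits tensor, and `ListLogitsAdapter` and `ban_token_zero` are hypothetical names that do not appear in the PR. CUDA stream handling is left out entirely.

```python
import logging
from typing import Callable, List, MutableSequence, Sequence

logger = logging.getLogger(__name__)

# Hypothetical stand-in for a Dynamo processor: it mutates `scores` in place.
Processor = Callable[[Sequence[int], MutableSequence[float]], None]


class ListLogitsAdapter:
    """Torch-free stand-in that mirrors the adapter's shape checks and error handling."""

    def __init__(self, processor: Processor):
        self.processor = processor

    def __call__(
        self,
        req_id: int,
        logits: List[List[List[float]]],  # stands in for a [batch, beam, vocab] tensor
        ids: List[List[int]],
    ) -> None:
        try:
            if len(logits) != 1:
                raise ValueError(f"expected batch size 1, got {len(logits)}")
            if len(logits[0]) != 1:
                raise ValueError("beam width > 1 is not supported")
            # Pass the single vocab-sized row; the processor mutates it in place.
            self.processor(ids[0], logits[0][0])
        except Exception as e:
            # As in the diff above, errors are logged and the logits are left alone.
            logger.error("Error in logits processor for request %s: %s", req_id, e)


def ban_token_zero(input_ids: Sequence[int], scores: MutableSequence[float]) -> None:
    """Hypothetical processor: make token 0 impossible to sample."""
    scores[0] = float("-inf")


adapter = ListLogitsAdapter(ban_token_zero)

ok = [[[1.0, 2.0, 3.0]]]  # batch 1, beam 1, vocab 3
adapter(7, ok, [[5, 6]])

bad = [[[1.0, 2.0]], [[3.0, 4.0]]]  # batch 2: rejected, logits left untouched
adapter(8, bad, [[5], [6]])
```

Because the wrapped processor writes into the row it receives, the caller sees the change in `ok` with nothing returned; the rejected `bad` input only produces a log entry.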
7 changes: 7 additions & 0 deletions
lib/bindings/python/src/dynamo/logits_processing/examples/__init__.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| | ||
| from .hello_world import HelloWorldLogitsProcessor | ||
| from .temperature import TemperatureProcessor | ||
| | ||
| __all__ = ["TemperatureProcessor", "HelloWorldLogitsProcessor"] |
42 changes: 42 additions & 0 deletions
lib/bindings/python/src/dynamo/logits_processing/examples/hello_world.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| | ||
| from typing import Sequence | ||
| | ||
| import torch | ||
| from transformers import PreTrainedTokenizerBase | ||
| | ||
| from dynamo.logits_processing import BaseLogitsProcessor | ||
| | ||
| RESPONSE = "Hello world!" | ||
| | ||
| | ||
| class HelloWorldLogitsProcessor(BaseLogitsProcessor): | ||
|     """ | ||
|     Sample Logits Processor that always outputs a hardcoded | ||
|     response (`RESPONSE`), no matter the input | ||
|     """ | ||
| | ||
|     def __init__(self, tokenizer: PreTrainedTokenizerBase): | ||
|         self.tokenizer = tokenizer | ||
|         self.token_ids = tokenizer.encode(RESPONSE, add_special_tokens=False) | ||
|         self.eos_id = tokenizer.eos_token_id | ||
|         if self.eos_id is None: | ||
|             raise ValueError( | ||
|                 "Tokenizer has no eos_token_id; HelloWorldLogitsProcessor requires one." | ||
|             ) | ||
|         self.state = 0 | ||
| | ||
|     def __call__(self, input_ids: Sequence[int], scores: torch.Tensor): | ||
|         mask = torch.full_like(scores, float("-inf")) | ||
| | ||
|         if self.state < len(self.token_ids): | ||
|             token_idx = self.token_ids[self.state] | ||
|         else: | ||
|             token_idx = self.eos_id | ||
|         # Allow only a single token to be output | ||
|         mask[token_idx] = 0.0 | ||
| | ||
|         # The `scores` tensor *must* also be modified in-place | ||
|         scores.add_(mask) | ||
|         self.state += 1 |
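
To see why adding a mask of `-inf` everywhere except one index forces the output, the same idea can be simulated with plain Python lists and greedy decoding. This is a rough sketch, not the PR's code: `ForcedSequenceProcessor` and `greedy_decode` are hypothetical names, the token IDs and the "model" logits are made up, and setting every other entry to `-inf` is used as the list equivalent of the in-place `add_` on a tensor with finite scores.

```python
from typing import List, Sequence

NEG_INF = float("-inf")


class ForcedSequenceProcessor:
    """Torch-free stand-in: forces `token_ids`, then `eos_id`, one token per call."""

    def __init__(self, token_ids: Sequence[int], eos_id: int):
        self.token_ids = list(token_ids)
        self.eos_id = eos_id
        self.state = 0  # number of calls made so far on this instance

    def __call__(self, input_ids: Sequence[int], scores: List[float]) -> None:
        if self.state < len(self.token_ids):
            allowed = self.token_ids[self.state]
        else:
            allowed = self.eos_id
        # For finite scores this matches adding a mask that is 0.0 at `allowed`
        # and -inf everywhere else.
        for i in range(len(scores)):
            if i != allowed:
                scores[i] = NEG_INF
        self.state += 1


def greedy_decode(processor, vocab_size: int, eos_id: int, max_steps: int = 10) -> List[int]:
    """Pick the highest-scoring token each step until EOS or `max_steps`."""
    generated: List[int] = []
    for _ in range(max_steps):
        scores = [0.1 * i for i in range(vocab_size)]  # made-up model logits
        processor(generated, scores)
        next_id = max(range(vocab_size), key=scores.__getitem__)
        generated.append(next_id)
        if next_id == eos_id:
            break
    return generated


# Hypothetical token IDs; a real tokenizer would supply these.
output = greedy_decode(ForcedSequenceProcessor([4, 2, 5], eos_id=0), vocab_size=8, eos_id=0)
print(output)  # [4, 2, 5, 0]
```

The step counter lives on the instance, so the sketch builds a fresh processor for each decode; reusing one instance across requests would continue from wherever the previous request stopped.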
41 changes: 41 additions & 0 deletions
lib/bindings/python/src/dynamo/logits_processing/examples/temperature.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| | ||
| from typing import Sequence | ||
| | ||
| import torch | ||
| | ||
| from dynamo.logits_processing import BaseLogitsProcessor | ||
| | ||
| | ||
| class TemperatureProcessor(BaseLogitsProcessor): | ||
|     """ | ||
|     Example logits processor that applies temperature scaling. | ||
| | ||
|     This is a simple demonstration of how to implement a logits processor | ||
|     that can be used with any Dynamo backend. | ||
|     """ | ||
| | ||
|     def __init__(self, temperature: float = 1.0): | ||
|         """ | ||
|         Args: | ||
|             temperature: Scaling factor. Higher values make distribution more uniform, | ||
|                 lower values make it more peaked. Must be positive. | ||
|         """ | ||
|         if temperature <= 0: | ||
|             raise ValueError("Temperature must be positive") | ||
|         self.temperature = temperature | ||
| | ||
|     def __call__(self, input_ids: Sequence[int], logits: torch.Tensor): | ||
|         """ | ||
|         Apply temperature scaling to logits. | ||
| | ||
|         Args: | ||
|             input_ids: Token IDs generated so far (unused in this simple example) | ||
|             logits: Raw logits tensor from model | ||
| | ||
|         The processor is expected to modify the logits in-place. | ||
|         """ | ||
|         if self.temperature == 1.0: | ||
|             return | ||
|         logits.div_(self.temperature) |
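
The docstring's claim that higher temperatures flatten the distribution and lower ones sharpen it can be checked numerically. The following is a small stdlib-only sketch, assuming softmax sampling downstream; `softmax`, `scale_in_place`, and `top_probability` are hypothetical helpers written for this illustration, and a plain list stands in for the logits tensor.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scale_in_place(logits: List[float], temperature: float) -> None:
    """List analogue of `logits.div_(temperature)`: divides every entry in place."""
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    for i in range(len(logits)):
        logits[i] /= temperature


def top_probability(logits: List[float], temperature: float) -> float:
    """Probability of the most likely token after temperature scaling."""
    scaled = list(logits)  # copy so the caller's logits stay unchanged
    scale_in_place(scaled, temperature)
    return max(softmax(scaled))


logits = [2.0, 1.0, 0.0]
p_cold = top_probability(logits, 0.5)  # lower temperature: more peaked
p_base = top_probability(logits, 1.0)  # unchanged distribution
p_hot = top_probability(logits, 2.0)   # higher temperature: flatter

print(p_cold > p_base > p_hot)  # True
```

Dividing by the temperature before softmax widens the gaps between logits when the temperature is below 1 and narrows them when it is above 1, which is why the most likely token's probability moves in the opposite direction to the temperature.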