Skip to content

Commit 3cbbd4c

Browse files
committed
chore: refactor logits processors to be in-place
1 parent f5744d8 commit 3cbbd4c

File tree

4 files changed

+18
-21
lines changed

4 files changed

+18
-21
lines changed

components/backends/trtllm/src/dynamo/trtllm/logits_processing/adapter.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,17 @@ def __call__(
5050
stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
5151
try:
5252
with torch.cuda.stream(stream):
53-
for idx, (ids_req, logits_req) in enumerate(zip(ids, logits)):
54-
if logits_req.shape[0] != 1:
55-
raise ValueError(
56-
"Logits processing with beam width > 1 is not supported"
57-
)
58-
# Remove dimension 0 from logits_req
59-
modified_logits = self.processor(ids_req, logits_req.reshape(-1))
60-
61-
# TRT-LLM expects in-place modification
62-
logits[idx, 0, :].copy_(modified_logits)
53+
if logits.shape[0] != 1:
54+
raise ValueError(
55+
f"This logits adapter only supports per-request logits processing. "
56+
f"Received logits with batch size {logits.shape[0]} expected 1"
57+
)
58+
if logits.shape[1] != 1:
59+
raise ValueError(
60+
"Logits processing with beam width > 1 is not supported"
61+
)
62+
# Call the processor which modifies the logits in-place
63+
self.processor(ids[0], logits[0, 0, :])
6364

6465
except Exception as e:
6566
logger.error(f"Error in logits processor for request {req_ids}: {e}")

lib/bindings/python/src/dynamo/logits_processing/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,14 @@ def __call__(
2626
self,
2727
input_ids: Sequence[int],
2828
logits: torch.Tensor,
29-
) -> torch.Tensor:
29+
):
3030
"""
3131
Process the logits for the next token prediction.
3232
3333
Args:
3434
input_ids: The input token IDs generated so far.
3535
logits: The raw logits for the next token. Shape: (vocab_size,)
3636
37-
Returns:
38-
A tensor with the same shape, dtype, and device as `logits`.
37+
The processor is expected to modify the logits in-place.
3938
"""
4039
...

lib/bindings/python/src/dynamo/logits_processing/examples/hello_world.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
2323
self.eos_id = tokenizer.eos_token_id
2424
self.state = 0
2525

26-
def __call__(self, input_ids: Sequence[int], scores: torch.Tensor) -> torch.Tensor:
26+
def __call__(self, input_ids: Sequence[int], scores: torch.Tensor):
2727
mask = torch.full_like(scores, float("-inf"))
2828

2929
if self.state < len(self.token_ids):
@@ -36,5 +36,3 @@ def __call__(self, input_ids: Sequence[int], scores: torch.Tensor) -> torch.Tens
3636
# The `scores` tensor *must* also be modified in-place
3737
scores.add_(mask)
3838
self.state += 1
39-
40-
return scores

lib/bindings/python/src/dynamo/logits_processing/examples/temperature.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,16 @@ def __init__(self, temperature: float = 1.0):
2626
raise ValueError("Temperature must be positive")
2727
self.temperature = temperature
2828

29-
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor) -> torch.Tensor:
29+
def __call__(self, input_ids: Sequence[int], logits: torch.Tensor):
3030
"""
3131
Apply temperature scaling to logits.
3232
3333
Args:
3434
input_ids: Token IDs generated so far (unused in this simple example)
3535
logits: Raw logits tensor from model
3636
37-
Returns:
38-
Temperature-scaled logits tensor
37+
The processor is expected to modify the logits in-place.
3938
"""
4039
if self.temperature == 1.0:
41-
return logits
42-
return logits / self.temperature
40+
return
41+
logits.div_(self.temperature)

0 commit comments

Comments (0)