Skip to content

Commit 421aa87

Browse files
committed
fix: use cuda stream provided by trtllm
1 parent 38f931c commit 421aa87

File tree

2 files changed

+12
-12
lines changed
  • components/backends/trtllm/src/dynamo/trtllm/logits_processing
  • lib/bindings/python/src/dynamo/logits_processing/examples

2 files changed

+12
-12
lines changed

components/backends/trtllm/src/dynamo/trtllm/logits_processing/adapter.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,19 @@ def __call__(
4747
Returns:
4848
Modified logits tensor (in-place modification expected by TRT-LLM)
4949
"""
50-
print(f"Shapes: logits {logits.shape}, ids {ids}")
50+
stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
5151
try:
52-
for ids_req, logits_req in zip(ids, logits):
53-
if logits_req.shape[0] != 1:
54-
raise ValueError(
55-
"Logits processing with beam width > 1 is not supported"
56-
)
57-
# Remove dimension 0 from logits_req
58-
modified_logits = self.processor(ids_req, logits_req.reshape(-1))
59-
60-
# TRT-LLM expects in-place modification
61-
logits.copy_(modified_logits)
52+
with torch.cuda.stream(stream):
53+
for idx, (ids_req, logits_req) in enumerate(zip(ids, logits)):
54+
if logits_req.shape[0] != 1:
55+
raise ValueError(
56+
"Logits processing with beam width > 1 is not supported"
57+
)
58+
# Remove dimension 0 from logits_req
59+
modified_logits = self.processor(ids_req, logits_req.reshape(-1))
60+
61+
# TRT-LLM expects in-place modification
62+
logits[idx, 0, :].copy_(modified_logits)
6263

6364
except Exception as e:
6465
logger.error(f"Error in logits processor for request {req_ids}: {e}")

lib/bindings/python/src/dynamo/logits_processing/examples/hello_world.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
2424
self.state = 0
2525

2626
def __call__(self, input_ids: Sequence[int], scores: torch.Tensor) -> torch.Tensor:
27-
print("Calling logits processor")
2827
mask = torch.full_like(scores, float("-inf"))
2928

3029
if self.state < len(self.token_ids):

0 commit comments

Comments
 (0)