Skip to content

Commit 0cd9608

Browse files
authored
Give biased_logits the same dtype as logits
Exllama generates logits in torch Half-dtype, but Outlines requires the Float-dtype. This small change converts the logits to the required dtype (whatever that might be), solving issue #583. Tested with Exllama on the example code on the github front page, and #583 is resolved.
1 parent eb692f6 commit 0cd9608

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

outlines/generate/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor:
282282
A view of the original logits tensor where some values are masked.
283283
284284
"""
285-
biased_logits = torch.full(logits.shape, -math.inf, device=logits.device)
285+
biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
286286
for i, ids in enumerate(allowed_token_ids):
287287
biased_logits[i, ids] = logits[i, ids]
288288
return biased_logits

0 commit comments

Comments
 (0)