Commit a90c97d

Use FP32 for log probabilities (#19)
1 parent: e3f00d1

File tree

1 file changed (+2, −1 lines)


cacheflow/models/sample.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -36,10 +36,11 @@ def forward(
         # Use in-place division to avoid creating a new tensor.
         logits.div_(t.unsqueeze(dim=1))

+        # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # Compute the log probabilities (before applying top-p).
-        logprobs = torch.log(probs, out=logits)
+        logprobs = torch.log(probs)

         # Apply top-p truncation.
         top_ps = _get_top_ps(input_metadata)
```
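The motivation for keeping probabilities and log probabilities in float32 is numerical range: in half precision, small tail probabilities underflow to zero, and their logarithms become `-inf`. The removed `out=logits` variant also wrote the result into the logits buffer, which carries the model's (possibly half-precision) dtype. A minimal sketch of the underflow effect, using NumPy in place of PyTorch (the logit values here are illustrative, not from the commit):

```python
import numpy as np

def softmax(x, dtype):
    # Numerically stable softmax computed in the given dtype.
    x = x.astype(dtype)
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([0.0, -12.0, -24.0])

p16 = softmax(logits, np.float16)
p32 = softmax(logits, np.float32)

# exp(-24) ~ 3.8e-11 is below float16's smallest subnormal (~6e-8),
# so p16[2] rounds to 0 and its log is -inf; float32 keeps it finite.
with np.errstate(divide="ignore"):
    log16 = np.log(p16)
log32 = np.log(p32)
```

Here `log32[2]` is close to the true value of about −24, while `log16[2]` has collapsed to `-inf`, which would corrupt any downstream use of the token's log probability.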

0 commit comments