1 parent e3f00d1 commit a90c97d
cacheflow/models/sample.py
@@ -36,10 +36,11 @@ def forward(
     # Use in-place division to avoid creating a new tensor.
     logits.div_(t.unsqueeze(dim=1))

+    # We use float32 for probabilities and log probabilities.
     # Compute the probabilities.
     probs = torch.softmax(logits, dim=-1, dtype=torch.float)
     # Compute the log probabilities (before applying top-p).
-    logprobs = torch.log(probs, out=logits)
+    logprobs = torch.log(probs)

     # Apply top-p truncation.
     top_ps = _get_top_ps(input_metadata)
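The change replaces torch.log(probs, out=logits) with torch.log(probs), so the log probabilities are stored in a fresh float32 tensor rather than written back into the logits buffer, which may hold a lower-precision dtype. A minimal sketch of the resulting dtype behavior, using hypothetical shapes, temperatures t, and a float16 logits tensor (none of these values come from the commit):

import torch

# Hypothetical inputs: 2 sequences, a vocabulary of 8 tokens, half-precision logits.
logits = torch.randn(2, 8, dtype=torch.float16)
t = torch.tensor([0.8, 1.2], dtype=torch.float16)  # hypothetical per-sequence temperatures

# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))

# Compute the probabilities in float32, as in the patched code.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)

# New behavior: torch.log allocates a new tensor that keeps float32 precision,
# instead of writing the result back into the float16 logits buffer.
logprobs = torch.log(probs)
assert probs.dtype == torch.float32 and logprobs.dtype == torch.float32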