We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
k_index
apply_top_k_only
1 parent 24b7fb4 commit 6efb195 (Copy full SHA for 6efb195)
vllm/v1/sample/ops/topk_topp_sampler.py
@@ -200,7 +200,7 @@ def apply_top_k_only(
200
# topk.values tensor has shape [batch_size, max_top_k].
201
# Convert top k to 0-based index in range [0, max_top_k).
202
k_index = k.sub_(1).unsqueeze(1)
203
- top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
+ top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
204
# Handle non-topk rows.
205
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
206
logits.masked_fill_(logits < top_k_mask, -float("inf"))
0 commit comments