Skip to content

Commit 6efb195

Browse files
authored
[V1] Fix: make sure k_index is int64 for apply_top_k_only (#15907)
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
1 parent 24b7fb4 commit 6efb195

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def apply_top_k_only(
200200
# topk.values tensor has shape [batch_size, max_top_k].
201201
# Convert top k to 0-based index in range [0, max_top_k).
202202
k_index = k.sub_(1).unsqueeze(1)
203-
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
203+
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
204204
# Handle non-topk rows.
205205
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
206206
logits.masked_fill_(logits < top_k_mask, -float("inf"))

0 commit comments

Comments
 (0)