Commit 1d12e45
[V1][TPU] TPU-optimized top-k implementation (2nd try)
Previously we found that using torch.topk resulted in a significant speedup on TPU. It turns out that is not a viable solution, because the return shape of torch.topk depends on k, which means an XLA recompilation is triggered every time k changes. Additionally, we realized that torch.scatter was the main bottleneck for the original top-k implementation on TPU. This PR circumvents both problems by using a threshold-based approach to find the top-k set. The algorithm is nearly identical to that of top-p; see #15736 for more details.

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
1 parent b942cf1 · commit 1d12e45
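The idea is sketched below as a minimal CPU-only example (the helper name topk_by_threshold and the toy values are hypothetical, not part of this commit): sort each row once, read off the k-th largest value as a per-row cut-off, and mask everything strictly below it. Every intermediate tensor keeps the fixed shape (batch, vocab) no matter what k is, so XLA compiles the graph once. The commit itself thresholds sorted softmax probabilities so that the machinery is shared with the top-p path; since softmax is monotonic, thresholding probabilities or raw logits selects the same top-k set.

import torch

def topk_by_threshold(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Keep each row's top-k logits by thresholding; shapes never depend on k."""
    vocab = logits.size(1)
    # Sort ascending once; the k-th largest value sits at index vocab - k.
    sorted_logits, _ = logits.sort(dim=-1, descending=False)
    cutoff_idx = (vocab - k.to(torch.long)).unsqueeze(dim=1)  # (batch, 1)
    cutoff = sorted_logits.gather(-1, cutoff_idx)             # (batch, 1)
    # Discard everything strictly below the cut-off; ties at the cut-off survive.
    return logits.masked_fill(logits < cutoff, -float("inf"))

logits = torch.tensor([[0.1, 0.9, 0.5, 0.3],
                       [2.0, -1.0, 0.0, 1.0]])
k = torch.tensor([2, 2])
print(topk_by_threshold(logits, k))
# Row 0 keeps 0.9 and 0.5; row 1 keeps 2.0 and 1.0; all else becomes -inf.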

File tree: 2 files changed (+37, -7 lines)


tests/v1/tpu/test_topk_topp_sampler.py (21 additions, 1 deletion)
@@ -5,7 +5,8 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
+from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
+                                                  apply_top_k_top_p_tpu)
 
 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -16,6 +17,25 @@
 TOLERANCE = 1e-6
 
 
+def test_topk_equivalence_to_native_impl():
+    with torch.device(xm.xla_device()):
+        xm.set_rng_state(seed=33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
+
+        # Random top-k values between 1 and 10.
+        k = torch.randint(1, 10, (BATCH_SIZE, ))
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
+                       VOCAB_SIZE)
+
+        result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
+
+        result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+        assert torch.allclose(result_native, result_tpu)
+
+
 def test_topp_result_sums_past_p():
     with torch.device(xm.xla_device()):
         xm.set_rng_state(seed=33)
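A note on the test's construction of k: it deliberately mixes small per-request k values with rows where top-k is disabled, since k == vocab_size keeps every token and must exercise the no-op path in the kernel. A minimal CPU-only illustration of the same pattern (the BATCH_SIZE and VOCAB_SIZE values here are hypothetical, not the test's actual constants):

import torch

BATCH_SIZE, VOCAB_SIZE = 8, 32  # toy sizes for illustration

torch.manual_seed(33)
# Per-request top-k values between 1 and 9.
k = torch.randint(1, 10, (BATCH_SIZE, ))
# Flip roughly half the rows to k == vocab size, which makes top-k a no-op.
disabled = torch.randint(0, 2, (BATCH_SIZE, ), dtype=torch.bool)
k.masked_fill_(disabled, VOCAB_SIZE)
print(k)  # entries equal to 32 mean "top-k disabled" for that request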

vllm/v1/sample/ops/topk_topp_sampler.py (16 additions, 6 deletions)
@@ -138,25 +138,35 @@ def apply_top_k_top_p_tpu(
     This algorithm avoids using torch.scatter which is extremely slow on TPU.
     This is achieved by finding a "cut-off" element in the original logit, and
     after thresholding the logit using this cut-off, the remaining elements
-    shall constitute the top-p set.
+    shall constitute the top-k/p set.
 
     Note: in the case of tie (i.e. multipple cut-off elements present in the
-    logit), all tie elements are included in the top-p set. In other words,
+    logit), all tie elements are included in the top-k/p set. In other words,
     this function does not break ties. Instead, these tie tokens have equal
     chance of being chosen during final sampling, so we can consider the tie
     being broken then.
     """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
     if k is not None:
-        logits = apply_top_k_only(logits, k)
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
 
     if p is not None:
-        probs = logits.softmax(dim=-1)
-        probs_sort, _ = probs.sort(dim=-1, descending=False)
         cumprob = torch.cumsum(probs_sort, dim=-1)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
 
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(dim=1)
         top_p_cutoff = probs_sort.gather(-1, top_p_count)
         elements_to_discard = probs < top_p_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))
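To make the docstring's tie behavior concrete, here is a small self-contained sketch that mirrors the new top-k branch on CPU (toy values, not from the commit): with k=2 and two tokens tied exactly at the cut-off probability, all three of {3.0, 2.0, 2.0} survive, so the kept set can be larger than k.

import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 2.0, 0.5]])  # two tokens tied at 2.0
k = torch.tensor([2])

probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)

# The k-th largest probability (index vocab - k in the ascending sort)
# serves as the cut-off value.
top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)

# Only strictly-below-cut-off tokens are discarded; both tied tokens survive.
masked = logits.masked_fill(probs < top_k_cutoff, -float("inf"))
print(masked)  # tensor([[-inf, 3., 2., 2., -inf]]) -- three tokens kept for k=2

Any tied tokens that remain then compete on equal footing during final sampling, which is why the docstring says the tie can be considered broken at that stage.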
