
Commit 4719505

[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
1 parent 6edbfa9 commit 4719505

File tree: 3 files changed, +29 −4 lines changed


tests/v1/tpu/test_sampler.py
Lines changed: 2 additions & 1 deletion

@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     sampling_params = SamplingParams(
         temperature=0.7,
         # top_p=0.6, # TODO too slow!
-        # top_k=10,
+        top_k=10,
         min_p=0.2,
         max_tokens=16)
     s = time()
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     # Second request with different params, but for which we
     # compiled for in previous eager iteration.
     sampling_params = SamplingParams(temperature=0.1,
+                                     top_k=12,
                                      min_p=0.8,
                                      max_tokens=24)
     s = time()
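
For context, the parameters exercised in this test map directly onto the public API. A minimal usage sketch of the now-enabled top-k path (the model name and prompt are illustrative placeholders, not part of this commit):

# Hedged sketch: drive the top-k sampling path via the public vLLM API.
# Model name and prompt are illustrative placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
params = SamplingParams(temperature=0.7, top_k=10, min_p=0.2, max_tokens=16)
outputs = llm.generate(["The quick brown fox"], params)
print(outputs[0].outputs[0].text)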

vllm/envs.py
Lines changed: 6 additions & 0 deletions

@@ -95,6 +95,7 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
+    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False


 def get_default_cache_root():
@@ -623,6 +624,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # an environment with potentially malicious users.
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
+
+    # If set, disables TPU-specific optimization for top-k & top-p sampling
+    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
+    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
+    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
 }

 # end-env-vars-definition
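
The added lambda is tri-state: when the variable is unset it returns None (falsy, so the optimization stays enabled); otherwise the string value is int-cast to bool. A standalone sketch of that parsing pattern, mirroring the lambda but independent of vLLM's envs module (parse_tpu_disable_flag is a hypothetical helper):

# Standalone sketch of the env-var parsing pattern above. Mirrors the
# lambda; parse_tpu_disable_flag is a hypothetical helper, not vLLM code.
import os
from typing import Optional

def parse_tpu_disable_flag() -> Optional[bool]:
    name = "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
    if name not in os.environ:
        return None  # unset: falsy, so the TPU optimization stays enabled
    return bool(int(os.environ[name]))  # "1" -> True, "0" -> False

os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"] = "1"
assert parse_tpu_disable_flag() is True
del os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]
assert parse_tpu_disable_flag() is None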

vllm/v1/sample/ops/topk_topp_sampler.py
Lines changed: 21 additions & 3 deletions

@@ -66,7 +66,14 @@ def __init__(self):
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
         elif current_platform.is_tpu():
-            self.forward = self.forward_tpu
+            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
+                logger.warning(
+                    "TPU-specific optimization for top-k & top-p sampling are "
+                    "disabled, falling back to PyTorch-native implementation "
+                    "which could be very slow.")
+                self.forward = self.forward_native
+            else:
+                self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native

@@ -105,8 +112,19 @@ def forward_tpu(
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # TODO Placeholder for TPU optimized topk/p kernel
-        # logits = apply_top_k_top_p(logits, k, p)
+        # If only top-k is specified, use pytorch's builtin topk op. This leads
+        # to significant speed up on TPU compared to using apply_top_k_top_p.
+        if k is not None and p is None:
+            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
+
+            mask = torch.ones_like(logits, dtype=torch.bool)
+            mask.scatter_(-1, topk_indices, False)
+            logits.masked_fill_(mask, float('-inf'))
+        else:
+            # TODO Placeholder for TPU optimized topp kernel
+            # logits = apply_top_k_top_p(logits, k, p)
+            pass
+
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
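
The core of the speedup is the mask-and-fill pattern in forward_tpu: keep only the top-k logits and push everything else to -inf before the softmax. A self-contained sketch of that pattern, simplified to a single scalar k (the real method receives an Optional[torch.Tensor]):

# Self-contained sketch of the top-k masking pattern from forward_tpu.
# Simplified: assumes one scalar k, not the per-request tensor vLLM passes.
import torch

def mask_to_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Find the indices of the k largest logits per row.
    _, topk_indices = torch.topk(logits, k, dim=-1)
    # Start from an all-True mask and clear the top-k positions...
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask.scatter_(-1, topk_indices, False)
    # ...then everything outside the top-k becomes -inf, i.e. zero
    # probability after softmax.
    return logits.masked_fill(mask, float('-inf'))

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
probs = mask_to_top_k(logits, k=2).softmax(dim=-1)
assert torch.count_nonzero(probs) == 2  # only the top-2 entries survive

Note that masking keeps the logits tensor at its original shape instead of gathering a smaller one, which fits XLA's preference for static shapes on TPU and avoids recompilation.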
