
Commit f1e396f

NickLucche and hyeygit authored and committed
[V1][TPU] Enable Top K (vllm-project#15489)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
1 parent 626bc19 commit f1e396f

File tree

6 files changed (+69, -26 lines)


tests/v1/tpu/test_sampler.py

Lines changed: 18 additions & 0 deletions
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import random
+
 import pytest
 
 from vllm import LLM, envs
@@ -39,3 +41,19 @@ def test_sampler_different(model_name: str):
     # Unsupported `seed` param.
     sampling_params = SamplingParams(temperature=0.3, seed=42)
     output2 = llm.generate(prompts, sampling_params)
+
+    # Batch-case with TopK
+    for B in [4, 16]:
+        p = prompts * B
+        sampling_params = [
+            SamplingParams(
+                temperature=0.1,
+                min_p=0.8,
+                max_tokens=64,
+                # Vary number of ks
+                top_k=random.randint(4, 12)) for _ in range(B)
+        ]
+        # Make sure first two reqs have the same K
+        sampling_params[0] = sampling_params[1]
+        output = llm.generate(p, sampling_params)
+        assert output[0].outputs[0].text == output[1].outputs[0].text
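
Note: the pattern this test exercises (one SamplingParams object per prompt, so each request in the batch carries its own top_k) looks roughly like the sketch below in user code; the model name and prompts are placeholders, not taken from this commit.

# Illustrative sketch only; model and prompts are made up, not from this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
prompts = ["The capital of France is"] * 4
# One SamplingParams per prompt: each request gets its own top_k.
params = [
    SamplingParams(temperature=0.1, top_k=k, max_tokens=16)
    for k in (4, 6, 8, 10)
]
outputs = llm.generate(prompts, params)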

tests/v1/tpu/test_topk_topp_sampler.py

Lines changed: 21 additions & 1 deletion
@@ -5,7 +5,8 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
+from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
+                                                  apply_top_k_top_p_tpu)
 
 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -16,6 +17,25 @@
 TOLERANCE = 1e-6
 
 
+def test_topk_equivalence_to_native_impl():
+    with torch.device(xm.xla_device()):
+        xm.set_rng_state(seed=33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
+
+        # Random top-k values between 1 and 10.
+        k = torch.randint(1, 10, (BATCH_SIZE, ))
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
+                       VOCAB_SIZE)
+
+        result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
+
+        result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+        assert torch.allclose(result_native, result_tpu)
+
+
 def test_topp_result_sums_past_p():
     with torch.device(xm.xla_device()):
         xm.set_rng_state(seed=33)
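
The equivalence being asserted is easiest to read against a plain per-row top-k reference. The helper below is a hypothetical restatement (not vLLM's apply_top_k_top_p) that keeps each row's k largest logits and sets the rest to -inf, skipping rows where k equals the vocabulary size; up to tie-breaking at the cutoff value, the TPU path is expected to produce the same masked logits.

import torch

def reference_top_k(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference, not vLLM code: per-row top-k masking via torch.topk.
    out = logits.clone()
    vocab_size = logits.shape[-1]
    for row in range(logits.shape[0]):
        kk = int(k[row])
        if kk >= vocab_size:
            continue  # top-k disabled for this row
        cutoff = torch.topk(out[row], kk).values[-1]
        out[row] = out[row].masked_fill(out[row] < cutoff, -float("inf"))
    return out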

vllm/envs.py

Lines changed: 0 additions & 6 deletions
@@ -103,7 +103,6 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
-    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
@@ -685,11 +684,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
 
-    # If set, disables TPU-specific optimization for top-k & top-p sampling
-    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
-    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
-    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
-
     # Gap between padding buckets for the forward pass. So we have
     # 8, we will run forward pass with [16, 24, 32, ...].
     "VLLM_TPU_BUCKET_PADDING_GAP":

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 15 additions & 12 deletions
@@ -72,14 +72,7 @@ def __init__(self):
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
         elif current_platform.is_tpu():
-            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
-                logger.warning(
-                    "TPU-specific optimization for top-k & top-p sampling are "
-                    "disabled, falling back to PyTorch-native implementation "
-                    "which could be very slow.")
-                self.forward = self.forward_native
-            else:
-                self.forward = self.forward_tpu
+            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
 
@@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu(
     chance of being chosen during final sampling, so we can consider the tie
     being broken then.
     """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
     if k is not None:
-        logits = apply_top_k_only(logits, k)
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
 
     if p is not None:
-        probs = logits.softmax(dim=-1)
-        probs_sort, _ = probs.sort(dim=-1, descending=False)
         cumprob = torch.cumsum(probs_sort, dim=-1)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
@@ -224,7 +227,7 @@ def apply_top_k_only(
     max_top_k = k.max()
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
-    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
+    k_index = k.sub_(1).unsqueeze(1)
     top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
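
The core idea in the new apply_top_k_top_p_tpu body is to avoid torch.topk, whose output width depends on the batch's maximum k and therefore triggers XLA recompilation when k changes. Instead, the per-row cutoff is read out of the ascending-sorted probabilities with a gather at index vocab_size - k, so every intermediate shape stays static. Below is a minimal standalone sketch of that trick (plain PyTorch, mirroring the diff rather than reproducing the full vLLM function):

import torch

def topk_mask_static_shapes(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Sketch of the cutoff-via-gather trick; no intermediate shape depends on k.
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)
    # The k-th largest probability sits at ascending index (vocab_size - k).
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
    top_k_cutoff = probs_sort.gather(-1, top_k_count)
    # Rows with k == vocab_size keep every token (top-k disabled).
    no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
    top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
    return logits.masked_fill(probs < top_k_cutoff, -float("inf"))

# Example: row 0 keeps its 3 largest tokens, row 1 (k == vocab size) is untouched.
masked = topk_mask_static_shapes(torch.randn(2, 8), k=torch.tensor([3, 8]))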

vllm/v1/sample/tpu/metadata.py

Lines changed: 8 additions & 5 deletions
@@ -10,7 +10,7 @@
     temperature=-1.0,
     min_p=0.0,
     # strictly disabled for now
-    # top_k=-1,
+    top_k=0,
     # top_p=0.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
@@ -99,11 +99,13 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
 
         fill_slice(input_batch.temperature_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
-        # TODO Temporarily disabled until sampling options are enabled
-        # fill_slice(input_batch.top_p_cpu_tensor)
-        # fill_slice(input_batch.top_k_cpu_tensor)
         fill_slice(input_batch.min_p_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
+        fill_slice(input_batch.top_k_cpu_tensor,
+                   DEFAULT_SAMPLING_PARAMS["top_k"])
+        # TODO Temporarily disabled until sampling options are enabled
+        # fill_slice(input_batch.top_p_cpu_tensor,
+        #            DEFAULT_SAMPLING_PARAMS["top_p"])
 
         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
@@ -112,6 +114,7 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
-            top_k=None,  # input_batch.top_k[:padded_num_reqs],
+            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                 xla_device))

vllm/v1/worker/tpu_model_runner.py

Lines changed: 7 additions & 2 deletions
@@ -920,14 +920,19 @@ def _precompile_select_hidden_states(self) -> None:
                                       device=self.device)
                 torch._dynamo.mark_dynamic(indices, 0)
                 self.select_hidden_states(dummy_hidden, indices)
-                logger.info(" -- num_tokens: %d", num_tokens)
+                logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
+                            num_reqs)
+                # Requests can't be more than tokens. But do compile for the
+                # next bigger value in case num_tokens uses bucketed padding.
+                if num_reqs >= min(num_tokens, self.max_num_reqs):
+                    break
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
         self._update_num_xla_graphs("select_hidden_states")
 
     def _precompile_sample_from_hidden(self) -> None:
-        logger.info("Compiling sampling with different input shapes.")
+        logger.info("Compiling sampling with different num_reqs.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
         for num_reqs in self.num_reqs_paddings:
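
The added break prunes the (num_tokens, num_reqs) precompilation grid: a batch can never hold more requests than tokens, so there is no point compiling request paddings far above the current token bucket, but one bucket at or above it is still compiled in case num_tokens itself was rounded up. A toy loop (padding values are hypothetical, not vLLM's defaults) showing which combinations survive:

# Toy illustration only; padding lists here are made up, not vLLM's.
num_tokens_paddings = [16, 32, 64]
num_reqs_paddings = [8, 16, 32, 64]
max_num_reqs = 32

for num_tokens in num_tokens_paddings:
    for num_reqs in num_reqs_paddings:
        print(f"compile: num_tokens={num_tokens}, num_reqs={num_reqs}")
        # Stop once num_reqs reaches the token bucket or the global request cap.
        if num_reqs >= min(num_tokens, max_num_reqs):
            break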
