
Commit c7eb537

njhill authored and Mu Huai committed

[V1][Sampler] Faster top-k only implementation (vllm-project#15478)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>

1 parent 1330b31 commit c7eb537
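
This commit adds a dedicated top-k-only path to the V1 sampler: when no top-p is requested, the k-th largest logit per row is found with `torch.topk` and used as a cutoff, instead of sorting the full 128K-entry vocabulary. A minimal sketch of why this is faster (not the vLLM code itself; the CUDA device and shapes mirror the new test below, and a single scalar K is a simplifying assumption where the commit supports a per-request k tensor):

# Sketch only: compare a full-vocab sort against a topk-based threshold.
import time

import torch

BATCH, VOCAB, K = 1024, 128 * 1024, 10
logits = torch.rand(BATCH, VOCAB, device="cuda")


def topk_via_sort(x: torch.Tensor, k: int) -> torch.Tensor:
    # Baseline: sort the whole vocab, then mask everything below the
    # k-th largest value in each row.
    x_sorted, _ = x.sort(dim=-1, descending=True)
    threshold = x_sorted[:, k - 1].unsqueeze(1)
    return x.masked_fill(x < threshold, -float("inf"))


def topk_via_topk(x: torch.Tensor, k: int) -> torch.Tensor:
    # Faster: only the top-k values are needed to find the threshold.
    threshold = x.topk(k, dim=-1).values[:, -1].unsqueeze(1)
    return x.masked_fill(x < threshold, -float("inf"))


for fn in (topk_via_sort, topk_via_topk):
    torch.cuda.synchronize()
    start = time.perf_counter()
    fn(logits, K)
    torch.cuda.synchronize()
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s")

assert torch.equal(topk_via_sort(logits, K), topk_via_topk(logits, K))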

File tree

3 files changed: +91 -5 lines changed

tests/v1/sample/test_topk_topp_sampler.py

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+from torch import Generator
+
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+
+DEVICE = "cuda"
+
+BATCH_SIZE = 1024
+VOCAB_SIZE = 128 * 1024
+
+
+def test_topk_impl_equivalance():
+
+    with torch.device(DEVICE):
+        generator = Generator(device=DEVICE).manual_seed(33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+        # Random top-k values between 1 and 9.
+        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(
+            torch.randint(0,
+                          2, (BATCH_SIZE, ),
+                          generator=generator,
+                          dtype=bool), VOCAB_SIZE)
+
+        # Top-k only implementation
+        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+
+        # Top-p + top-k
+        no_op_top_p = torch.tensor([1.0])
+        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+
+        assert torch.allclose(result1, result2)
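
The test checks the two implementations against each other through the same entry point: `p=None` now routes to the new top-k-only path, while a no-op top-p of 1.0 keeps the original sort-based path, so the outputs must match, including for the rows where top-k is disabled. The same check can be run with tiny CPU tensors for a quick local sanity pass (a sketch assuming a vLLM checkout containing this commit; the shapes are arbitrary):

# Sketch: the same equivalence check with small CPU tensors.
import torch

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p

logits = torch.rand(8, 32)
k = torch.randint(1, 5, (8, ))

# p=None takes the new apply_top_k_only() fast path.
fast = apply_top_k_top_p(logits.clone(), k=k, p=None)
# p=1.0 disables top-p but forces the sort-based reference path.
slow = apply_top_k_top_p(logits.clone(), k=k, p=torch.tensor([1.0]))

assert torch.allclose(fast, slow)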

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 48 additions & 5 deletions

@@ -19,6 +19,12 @@
 
 
 class TopKTopPSampler(nn.Module):
+    """
+    Module that performs optional top-k and top-p filtering followed by
+    weighted random sampling of logits.
+
+    Implementations may update the logits tensor in-place.
+    """
 
     def __init__(self):
         super().__init__()

@@ -84,7 +90,11 @@ def forward_native(
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        """PyTorch-native implementation of top-k and top-p sampling."""
+        """
+        PyTorch-native implementation of top-k and top-p sampling.
+
+        The logits tensor may be updated in-place.
+        """
         logits = apply_top_k_top_p(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)

@@ -136,10 +146,18 @@ def apply_top_k_top_p(
 ) -> torch.Tensor:
     """Apply top-k and top-p masks to the logits.
 
-    This function sorts the logits tensor, which can be slow for large batches.
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
     """
-    if k is None and p is None:
-        return logits
+    if p is None:
+        if k is None:
+            return logits
+
+        # Avoid sorting vocab for top-k only case.
+        return apply_top_k_only(logits, k)
+
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
 
     if k is not None:

@@ -153,7 +171,7 @@
     if p is not None:
         # Apply top-p.
         probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = probs_sort.cumsum(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
         top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
         # at least one
         top_p_mask[:, -1] = False

@@ -164,6 +182,31 @@
     return logits
 
 
+def apply_top_k_only(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply top-k mask to the logits.
+
+    This implementation doesn't involve sorting the entire vocab.
+
+    The logits tensor may be updated in-place.
+    """
+    no_top_k_mask = k == logits.shape[1]
+    # Set non-top-k rows to 1 so that we can gather.
+    k = k.masked_fill(no_top_k_mask, 1)
+    max_top_k = k.max()
+    # topk.values tensor has shape [batch_size, max_top_k].
+    # Convert top k to 0-based index in range [0, max_top_k).
+    k_index = k.sub_(1).unsqueeze(1)
+    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
+    # Handle non-topk rows.
+    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
+    logits.masked_fill_(logits < top_k_mask, -float("inf"))
+    return logits
+
+
 def random_sample(
     probs: torch.Tensor,
     generators: dict[int, torch.Generator],
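
The core of `apply_top_k_only` is a single `topk` over the whole batch followed by a `gather`: one `topk` call with `max(k)` covers every row, and each row then picks out its own k-th largest value as its cutoff. A worked toy example of that threshold-gather step (the values are invented for illustration):

# Sketch: the threshold-gather trick from apply_top_k_only on toy data.
import torch

logits = torch.tensor([[0.1, 0.9, 0.4, 0.7],
                       [0.3, 0.2, 0.8, 0.5]])
k = torch.tensor([1, 2])  # per-row top-k
max_top_k = int(k.max())  # 2: one topk call serves both rows

# Descending top values: row 0 -> [0.9, 0.7], row 1 -> [0.8, 0.5]
top_vals = logits.topk(max_top_k, dim=1).values
# Each row gathers its own k-th largest value (0-based index k - 1).
threshold = top_vals.gather(1, (k - 1).unsqueeze(1))  # [[0.9], [0.5]]

masked = logits.masked_fill(logits < threshold, -float("inf"))
print(masked)  # row 0 keeps only 0.9; row 1 keeps 0.8 and 0.5

Rows with k equal to the vocab size (top-k disabled) are handled by the `no_top_k_mask` bookkeeping in the commit: their gathered threshold is overwritten with -inf so nothing in those rows gets masked.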

vllm/v1/sample/sampler.py

Lines changed: 6 additions & 0 deletions

@@ -87,6 +87,12 @@ def sample(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
+        """Sample logits based on sampling metadata.
+
+        The various logits processing functions called in this method
+        may update the logits tensor in-place.
+        """
+
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
         if sampling_metadata.all_random:
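
The new docstring codifies a contract the other two files rely on: the logits tensor handed to these functions may be mutated (the masking uses `masked_fill_`, and the top-p path now writes its cumulative sum back into `probs_sort`). Callers that still need the original logits should clone first, exactly as the new test does. A small illustration of the pitfall (shapes are arbitrary; assumes the code as of this commit):

# Sketch: in-place masking means the caller's tensor changes too.
import torch

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p

logits = torch.rand(4, 16)
original = logits.clone()  # preserve a copy before the in-place update

out = apply_top_k_top_p(logits, k=torch.tensor([2, 2, 2, 2]), p=None)
assert out is logits                      # top-k-only path returns its input
assert not torch.equal(original, logits)  # non-top-k entries are now -inf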
