
Commit 29d9344

Revert "[Sampler] Adapt to FlashInfer 0.2.3 sampler API (vllm-project#15777)"
This reverts commit 7fdfa01.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 70812a9 commit 29d9344

File tree

7 files changed: +89 additions, -123 deletions


docker/Dockerfile

Lines changed: 1 addition & 2 deletions

@@ -255,10 +255,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 RUN --mount=type=cache,target=/root/.cache/uv \
     . /etc/environment && \
     if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
-        # uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.4/flashinfer_python-0.2.4+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
         # TESTING: install FlashInfer from source to test 2.7.0 final RC
         FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \
-            uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.4" ; \
+            uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \
     fi
 COPY examples examples
 COPY benchmarks benchmarks
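
For reference, a quick runtime sanity check (a minimal sketch, not part of the commit; it only assumes the image builds and FlashInfer imports) to confirm which FlashInfer build actually landed in the container:

import flashinfer

# Expected to report the pinned v0.2.2.post1 build rather than v0.2.4.
print(flashinfer.__version__)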

tests/samplers/test_rejection_sampler.py

Lines changed: 2 additions & 12 deletions

@@ -169,10 +169,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
 @pytest.mark.parametrize("n_rep", [100])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-# @pytest.mark.parametrize("use_flashinfer", [True, False])
-# Not testing FlashInfer now, since 0.2.3 API removed the ability
-# to pass in uniform samples.
-@pytest.mark.parametrize("use_flashinfer", [False])
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                    frac_seeded: float, n_rep: int, device: str,
@@ -217,10 +214,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-# @pytest.mark.parametrize("use_flashinfer", [True, False])
-# Not testing FlashInfer now, since 0.2.3 API removed the ability
-# to pass in uniform samples.
-@pytest.mark.parametrize("use_flashinfer", [False])
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
                             device: str, use_flashinfer: bool):
@@ -290,10 +284,6 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
     Test the flashinfer and nonflashinfer backend generate
     the same output metrics.
     """
-
-    pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
-                "the ability to pass in uniform samples.")
-
     torch.set_default_device(device)
     torch.manual_seed(0)
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
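
These parametrizations can include use_flashinfer=True again because, with the pre-0.2.3 API restored below, vLLM draws the uniform samples itself from per-request torch.Generator objects before handing them to FlashInfer. A minimal sketch of the property the seeded tests rely on (plain PyTorch on CPU, illustrative shapes only):

import torch

# Re-seeding the same generator reproduces the same uniform draws, so any
# sampling step that consumes them is reproducible as well.
gen = torch.Generator().manual_seed(42)
u1 = torch.empty(8, 4).uniform_(generator=gen)

gen.manual_seed(42)
u2 = torch.empty(8, 4).uniform_(generator=gen)

assert torch.equal(u1, u2)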

tests/samplers/test_sampler.py

Lines changed: 0 additions & 2 deletions

@@ -647,8 +647,6 @@ def test_flashinfer_fallback(seed: int, device: str):
     if not envs.VLLM_USE_FLASHINFER_SAMPLER:
         pytest.skip("Flashinfer sampler is disabled")

-    pytest.skip("After FlashInfer 0.2.3, sampling will never fail")
-
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)

Lines changed: 1 addition & 71 deletions

@@ -1,20 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
-import pytest
 import torch
-from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 from torch import Generator

-from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
-                                                  is_flashinfer_available)
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p

 DEVICE = "cuda"

 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024

-FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
-

 def test_topk_impl_equivalance():

@@ -41,67 +35,3 @@ def test_topk_impl_equivalance():
     result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

     assert torch.allclose(result1, result2)
-
-
-def test_flashinfer_sampler():
-    '''
-    This test verifies that the FlashInfer top-k and top-p sampling
-    implementation produces the same results as the Python implementation.
-
-    NOTE: FlashInfer did not directly expose an interface for fused top-k and
-    top-p prob renorm (it did provide fused sampling but we cannot compare
-    sampling results due to randomness), so we will compare the probability
-    renormed consequently by top-k and then top-p of FlashInfer implementation.
-    '''
-
-    if not FLASHINFER_ENABLED:
-        pytest.skip(
-            "FlashInfer not installed or not available on this platform.")
-
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(42)
-
-        # Generate random logits
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Generate various top-k and top-p values
-        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-        p_values = torch.rand(
-            (BATCH_SIZE, ),
-            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
-
-        # Sometimes disable top-k (k=vocab_size)
-        k_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), VOCAB_SIZE)
-
-        # Sometimes disable top-p (p=1.0)
-        p_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), 1.0)
-
-        python_logits = apply_top_k_top_p(
-            logits=logits.clone(),
-            k=k_values,
-            p=p_values,
-        )
-        python_probs = torch.softmax(python_logits, dim=-1)
-
-        # FlashInfer only exposed renorm interfaces for probs so convert first
-        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
-        flashinfer_probs = top_k_renorm_probs(
-            probs=flashinfer_probs,
-            top_k=k_values,
-        )
-        flashinfer_probs = top_p_renorm_probs(
-            probs=flashinfer_probs,
-            top_p=p_values,
-        )
-
-        # Compare the results
-        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-            "FlashInfer and Python sampling implementations do not match!"

vllm/model_executor/layers/rejection_sampler.py

Lines changed: 6 additions & 7 deletions

@@ -123,13 +123,12 @@ def forward(
         # for rejection sampling
         if self.use_flashinfer and chain_speculative_sampling is not None:
             batch_size, k, _ = draft_probs.shape
-
-            (output_token_ids, accepted_token_num,
-             emitted_token_num) = chain_speculative_sampling(
-                 draft_probs,
-                 draft_token_ids,
-                 target_with_bonus_probs,
-             )
+            uniform_samples = self._create_uniform_samples(
+                seeded_seqs, batch_size, k, draft_probs.device)
+            output_token_ids, accepted_token_num, emitted_token_num \
+                = chain_speculative_sampling(
+                    draft_probs, draft_token_ids, uniform_samples,
+                    target_with_bonus_probs)

             # num_emitted_tokens returned by flashinfer
             # does not include the bonus token
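
The restored branch drives chain_speculative_sampling with caller-supplied uniform samples. A rough sketch of what a _create_uniform_samples-style helper does (a hypothetical standalone version; the (batch_size, k + 1) shape, one draw per draft token plus the bonus token, is an assumption read off this diff, not a verbatim copy of the vLLM helper):

from typing import Dict, Optional

import torch


def create_uniform_samples(seeded_seqs: Optional[Dict[int, torch.Generator]],
                           batch_size: int, k: int,
                           device: torch.device) -> torch.Tensor:
    # Unseeded rows get ordinary global draws.
    uniform = torch.rand(batch_size, k + 1, device=device)
    # Seeded requests overwrite their row with generator-driven draws so the
    # accepted/emitted tokens stay deterministic for that request.
    for row, generator in (seeded_seqs or {}).items():
        uniform[row] = torch.rand(k + 1, device=device, generator=generator)
    return uniform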

vllm/model_executor/layers/sampler.py

Lines changed: 39 additions & 13 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import warnings
 from collections.abc import Iterator
 from dataclasses import dataclass
 from importlib.util import find_spec
@@ -23,6 +24,7 @@
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

 if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
+    import flashinfer.sampling
     # yapf: disable
     from flashinfer.sampling import (
         top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
@@ -31,10 +33,6 @@
 else:
     flashinfer_top_k_top_p_sampling = None

-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-

 def get_sampler() -> torch.nn.Module:
     if envs.VLLM_USE_V1:
@@ -547,15 +545,38 @@ def _multinomial(
 def _top_k_top_p_multinomial_with_flashinfer(
         probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
         num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
+    max_top_k_round = 32
     if num_samples > 1:
         probs = probs.repeat_interleave(num_samples, dim=0)
         top_ks = top_ks.repeat_interleave(num_samples)
         top_ps = top_ps.repeat_interleave(num_samples)
-    batch_next_token_ids = flashinfer_top_k_top_p_sampling(
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if seq_groups is None:
+        uniform_samples.uniform_()
+    else:
+        sample_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            uniform_samples[:, sample_idx:sample_idx +
+                            stride].uniform_(generator=seq_group.generator)
+            sample_idx += stride
+    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
         probs,
+        uniform_samples,
         top_ks,
         top_ps,
     )
+    if not success.all():
+        warnings.warn("FlashInfer rejection sampling failed, fallback.",
+                      stacklevel=1)
+        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
+        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
+        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0])
     return batch_next_token_ids.view(-1, num_samples)


@@ -691,14 +712,19 @@ def _sample_with_torch(
                 seq_groups)

             if flashinfer_top_k_top_p_sampling is not None:
-                logger.warning("FlashInfer 0.2.3+ does not support "
-                               "per-request generators. Falling back to "
-                               "PyTorch-native implementation.")
-
-            multinomial_samples[sampling_type] = _multinomial(
-                probs[long_sample_indices],
-                max_n_in_batch,
-                seq_groups=seq_groups_arg)
+                multinomial_samples[
+                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
+                        probs[long_sample_indices],
+                        sampling_tensors.top_ks[long_sample_indices],
+                        sampling_tensors.top_ps[long_sample_indices],
+                        max_n_in_batch,
+                        seq_groups_arg,
+                    )
+            else:
+                multinomial_samples[sampling_type] = _multinomial(
+                    probs[long_sample_indices],
+                    max_n_in_batch,
+                    seq_groups=seq_groups_arg)

             if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
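
Taken together, the restored helper follows a draw-uniforms / try-fused-sampling / fall-back-on-failure pattern. A condensed, hedged sketch of that flow against the pre-0.2.3 flashinfer.sampling API (assumes FlashInfer < 0.2.3 is installed and probs is a normalized (batch, vocab) CUDA tensor; the wrapper name is illustrative, not vLLM code):

import torch
import flashinfer.sampling


def sample_with_fallback(probs: torch.Tensor, top_ks: torch.Tensor,
                         top_ps: torch.Tensor,
                         max_top_k_round: int = 32) -> torch.Tensor:
    # Pre-0.2.3 FlashInfer expects caller-provided uniforms, one row per
    # rejection-sampling round.
    uniform_samples = torch.empty((max_top_k_round, probs.shape[0]),
                                  device=probs.device).uniform_()
    token_ids, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
        probs, uniform_samples, top_ks, top_ps)
    if not success.all():
        # Rejection sampling ran out of rounds for some rows: renormalize
        # explicitly, then draw once from the renormalized distribution.
        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
        token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0])
    return token_ids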

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 40 additions & 16 deletions

@@ -31,10 +31,21 @@ def __init__(self):
         if current_platform.is_cuda():
             if is_flashinfer_available:
                 flashinfer_version = flashinfer.__version__
-                if flashinfer_version < "0.2.3":
-                    logger.warning(
-                        "FlashInfer version >= 0.2.3 required. "
-                        "Falling back to default sampling implementation.")
+                if flashinfer_version >= "0.2.3":
+                    # FIXME(DefTruth): Currently, we have errors when using
+                    # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
+                    # workaround, we disable FlashInfer for top-p & top-k
+                    # sampling by default while FlashInfer>=v0.2.3.
+                    # The sampling API removes the success return value
+                    # of all sampling API, which is not compatible with
+                    # earlier design.
+                    # https://github.com/flashinfer-ai/flashinfer/releases/
+                    # tag/v0.2.3
+                    logger.info(
+                        "Currently, FlashInfer top-p & top-k sampling sampler "
+                        "is disabled because FlashInfer>=v0.2.3 is not "
+                        "backward compatible. Falling back to the PyTorch-"
+                        "native implementation of top-p & top-k sampling.")
                     self.forward = self.forward_native
                 elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
@@ -95,11 +106,6 @@ def forward_cuda(
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
             return random_sample(probs, generators)
-        if generators:
-            logger.warning("FlashInfer 0.2.3+ does not support "
-                           "per-request generators. Falling back to "
-                           "PyTorch-native implementation.")
-            return self.forward_native(logits, generators, k, p)
         return flashinfer_sample(probs, k, p, generators)

     def forward_tpu(
@@ -274,18 +280,36 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
+    max_top_k_round = 32
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if len(generators) != batch_size:
+        uniform_samples.uniform_()
+    if generators:
+        for i, generator in generators.items():
+            uniform_samples[:, i].uniform_(generator=generator)

     if k is None:
         # Top-p only.
-        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
-            probs, p, deterministic=True)
+        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
+            probs, uniform_samples, p, deterministic=True)
     elif p is None:
         # Top-k only.
-        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
-            probs, k, deterministic=True)
+        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
+            probs, uniform_samples, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
-
+        next_token_ids, success = (
+            flashinfer.sampling.top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, k, p, deterministic=True))
+
+    # NOTE: CPU-GPU synchronization happens here.
+    if not success.all():
+        if k is not None:
+            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
+        if p is not None:
+            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
+        next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0], deterministic=True)
     return next_token_ids.view(-1)
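
Here the generators dict maps batch indices to per-request torch.Generator objects: the uniform tensor is bulk-filled first, then each seeded request overwrites its own column so its draws stay reproducible. A small illustrative sketch of just that filling step (made-up batch size and seed, CPU tensors; not the full flashinfer_sample function):

import torch

max_top_k_round, batch_size = 32, 4
generators = {1: torch.Generator().manual_seed(7)}  # request 1 is seeded

uniform_samples = torch.empty((max_top_k_round, batch_size))
if len(generators) != batch_size:
    # Bulk fill; unseeded columns keep these global draws.
    uniform_samples.uniform_()
for i, generator in generators.items():
    # Seeded columns are overwritten with generator-driven draws.
    uniform_samples[:, i].uniform_(generator=generator)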
