Skip to content

Commit 32ef498

Browse files
authored
[V1] Temporarily disable FlashInfer Rejection Sampler (#14788)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent ad19c8a commit 32ef498

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):
2222

2323
def __init__(self):
2424
super().__init__()
25-
if current_platform.is_cuda:
25+
if current_platform.is_cuda():
2626
if is_flashinfer_available:
2727
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
2828
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for

vllm/v1/sample/rejection_sampler.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,18 @@ class RejectionSampler(nn.Module):
2424

2525
def __init__(self):
2626
super().__init__()
27-
if current_platform.is_cuda:
27+
if current_platform.is_cuda():
2828
if is_flashinfer_available:
2929
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
30+
# FIXME(woosuk): Currently, we have errors when using
31+
# FlashInfer for rejection sampling. As a workaround, we
32+
# disable FlashInfer for rejection sampling by default.
33+
logger.info("Currently, FlashInfer rejection sampler is "
34+
"disabled because of a bug. Falling back to "
35+
"the PyTorch-native implementation of "
36+
"rejection sampling.")
37+
self.forward_method = self.forward_native
38+
3039
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
3140
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
3241
# default it is unused). For backward compatibility, we set
@@ -35,8 +44,8 @@ def __init__(self):
3544
# None means False, while in V1, None means True. This is
3645
# why we use the condition
3746
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
38-
logger.info("Using FlashInfer for rejection sampling.")
39-
self.forward_method = self.flashinfer_sample
47+
# logger.info("Using FlashInfer for rejection sampling.")
48+
# self.forward_method = self.flashinfer_sample
4049
else:
4150
logger.warning(
4251
"FlashInfer is available, but it is not enabled. "

0 commit comments

Comments
 (0)