@@ -24,9 +24,18 @@ class RejectionSampler(nn.Module):
 
     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda:
+        if current_platform.is_cuda():
             if is_flashinfer_available:
                 if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # FIXME(woosuk): Currently, we have errors when using
+                    # FlashInfer for rejection sampling. As a workaround, we
+                    # disable FlashInfer for rejection sampling by default.
+                    logger.info("Currently, FlashInfer rejection sampler is "
+                                "disabled because of a bug. Falling back to "
+                                "the PyTorch-native implementation of "
+                                "rejection sampling.")
+                    self.forward_method = self.forward_native
+
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                     # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                     # default it is unused). For backward compatibility, we set
@@ -35,8 +44,8 @@ def __init__(self):
                     # None means False, while in V1, None means True. This is
                     # why we use the condition
                     # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
-                    logger.info("Using FlashInfer for rejection sampling.")
-                    self.forward_method = self.flashinfer_sample
+                    # logger.info("Using FlashInfer for rejection sampling.")
+                    # self.forward_method = self.flashinfer_sample
                 else:
                     logger.warning(
                         "FlashInfer is available, but it is not enabled. "