@@ -24,7 +24,24 @@ def __init__(self):
         super().__init__()
         if current_platform.is_cuda():
             if is_flashinfer_available:
-                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                flashinfer_version = flashinfer.__version__
+                if flashinfer_version >= "0.2.3":
+                    # FIXME(DefTruth): Currently, we have errors when using
+                    # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
+                    # workaround, we disable FlashInfer for top-p & top-k
+                    # sampling by default when FlashInfer>=v0.2.3.
+                    # v0.2.3 removed the success return value from all
+                    # sampling APIs, which is not compatible with the
+                    # earlier design.
+                    # https://github.com/flashinfer-ai/flashinfer/releases/
+                    # tag/v0.2.3
+                    logger.info(
+                        "Currently, the FlashInfer top-p & top-k sampler "
+                        "is disabled because FlashInfer>=v0.2.3 is not "
+                        "backward compatible. Falling back to the PyTorch-"
+                        "native implementation of top-p & top-k sampling.")
+                    self.forward = self.forward_native
+                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                     # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                     # default it is unused). For backward compatibility, we set
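
One subtlety in the gate above: it compares version strings lexicographically, so a release like "0.10.0" sorts below "0.2.3" as a string and would skip the fallback. Below is a minimal sketch of a more robust check, not vLLM's actual code; it assumes the `packaging` library is installed, and the helper name `use_native_sampler` is hypothetical.

```python
# A minimal sketch, assuming `packaging` is available; the helper name
# `use_native_sampler` is hypothetical, not part of vLLM.
from packaging.version import Version


def use_native_sampler(flashinfer_version: str) -> bool:
    """Return True when the PyTorch-native top-p/top-k path should be used."""
    # FlashInfer v0.2.3 removed the success return value from its sampling
    # APIs, so v0.2.3 and later fall back to the native implementation.
    # Parsing the versions avoids the lexicographic pitfall where
    # "0.10.0" < "0.2.3" as plain strings.
    return Version(flashinfer_version) >= Version("0.2.3")


assert use_native_sampler("0.2.3")
assert use_native_sampler("0.10.0")  # a plain string compare would say False here
assert not use_native_sampler("0.2.2")
```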