@@ -31,10 +31,21 @@ def __init__(self):
         if current_platform.is_cuda():
             if is_flashinfer_available:
                 flashinfer_version = flashinfer.__version__
-                if flashinfer_version < "0.2.3":
-                    logger.warning(
-                        "FlashInfer version >= 0.2.3 required. "
-                        "Falling back to default sampling implementation.")
+                if flashinfer_version >= "0.2.3":
+                    # FIXME(DefTruth): Currently, we hit errors when using
+                    # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
+                    # workaround, FlashInfer is disabled by default for
+                    # top-p & top-k sampling when FlashInfer>=v0.2.3.
+                    # FlashInfer v0.2.3 removed the success return value
+                    # from all sampling APIs, which is not compatible with
+                    # the earlier design.
+                    # https://github.com/flashinfer-ai/flashinfer/releases/
+                    # tag/v0.2.3
+                    logger.info(
+                        "FlashInfer's top-p & top-k sampler is disabled "
+                        "because FlashInfer>=v0.2.3 is not backward "
+                        "compatible. Falling back to the PyTorch-native "
+                        "implementation of top-p & top-k sampling.")
                     self.forward = self.forward_native
                 elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
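For context, `self.forward = self.forward_native` switches the sampler to vLLM's PyTorch-native top-p & top-k path, which this diff does not show. A minimal sketch of the standard masking-then-multinomial approach such a fallback takes (illustrative only, with scalar k and p; not vLLM's actual forward_native):

from typing import Optional

import torch


def native_top_k_top_p_sample(
        logits: torch.Tensor,
        k: Optional[int],
        p: Optional[float],
        generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # Top-k: mask out everything below the k-th largest logit per row.
    if k is not None:
        kth = logits.topk(k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p: drop the sorted tail once cumulative probability exceeds p,
    # always keeping at least the most likely token per row.
    if p is not None:
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum_probs > p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(
            -1, sorted_idx, sorted_logits)
    probs = logits.softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1,
                             generator=generator).squeeze(-1)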
@@ -95,11 +106,6 @@ def forward_cuda(
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
             return random_sample(probs, generators)
-        if generators:
-            logger.warning("FlashInfer 0.2.3+ does not support "
-                           "per-request generators. Falling back to "
-                           "PyTorch-native implementation.")
-            return self.forward_native(logits, generators, k, p)
         return flashinfer_sample(probs, k, p, generators)
 
     def forward_tpu(
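The deleted warning existed because FlashInfer>=v0.2.3 draws randomness inside its kernels and cannot honor per-request torch.Generator objects. The pre-0.2.3 API restored in the next hunk takes caller-supplied uniform noise instead, so seeded generators can fill their own batch columns. A small standalone sketch of that seeding pattern (shapes and names mirror the hunk below; the seed and batch size are illustrative):

import torch

max_top_k_round, batch_size = 32, 4
# Request 2 carries its own seeded generator; others use the global RNG.
generators = {2: torch.Generator().manual_seed(42)}

uniform_samples = torch.empty((max_top_k_round, batch_size))
uniform_samples.uniform_()  # fill all columns from the global RNG first
for i, generator in generators.items():
    # A column slice is a view, so uniform_ writes into the base tensor.
    uniform_samples[:, i].uniform_(generator=generator)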
@@ -274,18 +280,36 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
+    max_top_k_round = 32
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if len(generators) != batch_size:
+        uniform_samples.uniform_()
+    if generators:
+        for i, generator in generators.items():
+            uniform_samples[:, i].uniform_(generator=generator)
 
     if k is None:
         # Top-p only.
-        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
-            probs, p, deterministic=True)
+        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
+            probs, uniform_samples, p, deterministic=True)
     elif p is None:
         # Top-k only.
-        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
-            probs, k, deterministic=True)
+        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
+            probs, uniform_samples, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
-
+        next_token_ids, success = (
+            flashinfer.sampling.top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, k, p, deterministic=True))
+
+    # NOTE: CPU-GPU synchronization happens here.
+    if not success.all():
+        if k is not None:
+            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
+        if p is not None:
+            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
+        next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0], deterministic=True)
     return next_token_ids.view(-1)
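The restored `success` handling reflects how pre-0.2.3 FlashInfer samples: rejection sampling with up to `max_top_k_round` rounds per request, a per-row success flag, and a renormalize-then-sample fallback for rows where every round was rejected. A pure-PyTorch sketch of the top-p renormalization step (illustrative; `flashinfer.sampling.top_p_renorm_prob` is a fused kernel, not this code):

import torch


def top_p_renorm(probs: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Keep the smallest prefix of the sorted distribution whose mass
    # reaches p (at least one token per row), zero the tail, renormalize.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    exclusive_cum = sorted_probs.cumsum(dim=-1) - sorted_probs
    keep = exclusive_cum < p.unsqueeze(-1)
    sorted_probs = sorted_probs * keep
    renorm = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    return renorm / renorm.sum(dim=-1, keepdim=True)

After renormalization the support exactly matches the top-p (or top-k) set, so `sampling_from_probs` reduces to a single exact draw per row and never rejects.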