
Commit 6e783bc

[Bugfix] Fix CUDA graph selection bug in FlashInfer at high concurrency (#26499)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
1 parent: c9d33c6

1 file changed: 9 additions, 2 deletions


vllm/v1/attention/backends/flashinfer.py

Lines changed: 9 additions & 2 deletions
@@ -296,6 +296,12 @@ def __init__(
         )
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
+        speculative_config = vllm_config.speculative_config
+        num_spec_tokens = (
+            speculative_config.num_speculative_tokens
+            if speculative_config is not None
+            else 0
+        )
         self.enable_cuda_graph = (
             self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
         )
@@ -306,7 +312,8 @@ def __init__(
             int, BatchDecodeWithPagedKVCacheWrapper
         ] = {}
         self._decode_cudagraph_max_bs = min(
-            max_num_reqs, self.compilation_config.max_capture_size
+            (1 + num_spec_tokens) * max_num_reqs,
+            self.compilation_config.max_capture_size,
         )
 
         self.num_qo_heads = self.model_config.get_num_attention_heads(
@@ -679,7 +686,7 @@ def build(
         use_cudagraph = (
             self.enable_cuda_graph
             and pure_decode
-            and num_decodes <= self._decode_cudagraph_max_bs
+            and num_decode_tokens <= self._decode_cudagraph_max_bs
         )
         if use_cudagraph:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
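
Why the change matters: with speculative decoding enabled, each decode request carries its next token plus num_speculative_tokens draft tokens, so the number of decode query tokens can exceed the number of requests. The old code capped the CUDA graph batch size at max_num_reqs and compared it against the request count (num_decodes), which at high concurrency could route a batch to a CUDA graph size that was never captured. Below is a minimal standalone sketch of the corrected sizing logic; the helper names are hypothetical and not part of the vLLM API.

# Minimal sketch of the sizing logic in this fix; helper names are
# hypothetical and not part of the vLLM API.

def decode_cudagraph_max_bs(max_num_reqs: int,
                            num_spec_tokens: int,
                            max_capture_size: int) -> int:
    # Largest decode batch, measured in query tokens, that a captured
    # CUDA graph can serve: each request contributes 1 + num_spec_tokens
    # tokens, clamped to the largest captured size.
    return min((1 + num_spec_tokens) * max_num_reqs, max_capture_size)

def can_use_cudagraph(pure_decode: bool,
                      num_decode_tokens: int,
                      max_bs: int) -> bool:
    # Eligibility is now judged on token count (num_decode_tokens), not
    # request count (num_decodes), so speculative tokens are accounted for.
    return pure_decode and num_decode_tokens <= max_bs

# Example: 128 concurrent requests, 2 speculative tokens each, capture cap 512.
max_bs = decode_cudagraph_max_bs(max_num_reqs=128, num_spec_tokens=2,
                                 max_capture_size=512)
assert max_bs == 384
assert can_use_cudagraph(True, num_decode_tokens=384, max_bs=max_bs)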
