1 file changed: vllm/v1/attention/backends (+9 −2)

@@ -296,6 +296,12 @@ def __init__(
         )
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
+        speculative_config = vllm_config.speculative_config
+        num_spec_tokens = (
+            speculative_config.num_speculative_tokens
+            if speculative_config is not None
+            else 0
+        )
         self.enable_cuda_graph = (
             self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
         )
@@ -306,7 +312,8 @@ def __init__(
             int, BatchDecodeWithPagedKVCacheWrapper
         ] = {}
         self._decode_cudagraph_max_bs = min(
-            max_num_reqs, self.compilation_config.max_capture_size
+            (1 + num_spec_tokens) * max_num_reqs,
+            self.compilation_config.max_capture_size,
         )

         self.num_qo_heads = self.model_config.get_num_attention_heads(
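With speculative decoding, each decode request contributes 1 + num_speculative_tokens query tokens per step, so the CUDA-graph capture limit is now sized in tokens rather than requests. A minimal sketch of the new sizing arithmetic, with made-up config values (in vLLM the real ones come from the scheduler, speculative, and compilation configs):

    # Illustrative values only, not from this PR.
    max_num_reqs = 64        # scheduler_config.max_num_seqs
    num_spec_tokens = 2      # speculative_config.num_speculative_tokens
    max_capture_size = 512   # compilation_config.max_capture_size

    # Largest pure-decode batch measured in query tokens: each request
    # carries its own token plus the speculated ones.
    decode_cudagraph_max_bs = min(
        (1 + num_spec_tokens) * max_num_reqs,  # 192 tokens
        max_capture_size,
    )
    assert decode_cudagraph_max_bs == 192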
@@ -679,7 +686,7 @@ def build(
         use_cudagraph = (
             self.enable_cuda_graph
             and pure_decode
-            and num_decodes <= self._decode_cudagraph_max_bs
+            and num_decode_tokens <= self._decode_cudagraph_max_bs
         )
         if use_cudagraph:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
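The build-time guard must compare the same unit. Previously num_decodes (a request count) was checked against a limit that is now token-scaled, which could admit batches whose token count exceeds the largest captured graph. A hedged sketch of the failure mode the fix closes, again with invented numbers:

    # Illustrative values only.
    max_capture_size = 128
    num_spec_tokens = 2
    decode_cudagraph_max_bs = min((1 + num_spec_tokens) * 64, max_capture_size)  # 128

    num_decodes = 60                                         # decode requests
    num_decode_tokens = num_decodes * (1 + num_spec_tokens)  # 180 query tokens

    # Old guard: 60 <= 128 passes, yet the 180-token batch has no captured
    # graph to replay. New guard compares token counts and falls back:
    use_cudagraph = num_decode_tokens <= decode_cudagraph_max_bs  # 180 <= 128 -> False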