@@ -154,6 +154,8 @@ class CompilationConfig:
     - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
     - [`cudagraph_capture_sizes`]
     [vllm.config.CompilationConfig.cudagraph_capture_sizes]
+    - [`max_cudagraph_capture_size`]
+    [vllm.config.CompilationConfig.max_cudagraph_capture_size]
     - [`cudagraph_num_of_warmups`]
     [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
     - [`cudagraph_copy_inputs`]
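The options listed above, including the max_cudagraph_capture_size field this diff adds, are all set through CompilationConfig. A minimal usage sketch (model name and values are illustrative only; passing a string for cudagraph_mode is assumed to coerce to the enum, as vLLM's config validation normally allows):

from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",   # illustrative model choice
    compilation_config=CompilationConfig(
        cudagraph_mode="FULL_AND_PIECEWISE",    # see cudagraph_mode docs
        max_cudagraph_capture_size=512,         # the field added in this diff
        cudagraph_num_of_warmups=1,
    ),
)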
@@ -327,18 +329,16 @@ class CompilationConfig:
     more modes may be added.
     """
     use_cudagraph: bool = True
-    """Whether to use cudagraph inside compilation.
-    - False: cudagraph inside compilation is not used.
+    """Whether to use cudagraph inside compilation:
+
+    - False: cudagraph inside compilation is not used.\n
     - True: cudagraph inside compilation is used. It requires
         that all input buffers have fixed addresses, and all
         splitting ops write their outputs to input buffers.
-    In the vLLM V1 Engine, this flag only applies for
-    CompilationMode.VLLM_COMPILE (aka -O3).
-    Note that this is orthogonal to the cudagraph capture logic
-    outside of compilation.
+
     Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
-    instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use
+    cudagraph_mode=FULL_AND_PIECEWISE instead.
     """
     cudagraph_num_of_warmups: int = 0
     """Number of warmup runs for cudagraph.
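A hedged migration sketch for the deprecation above (assuming CUDAGraphMode is importable from vllm.config alongside CompilationConfig):

from vllm.config import CompilationConfig, CUDAGraphMode  # import path assumed

# Before (deprecated flag):
cfg = CompilationConfig(use_cudagraph=True)

# After: express the same intent through cudagraph_mode, as the
# updated warning text recommends.
cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE)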
@@ -398,8 +398,22 @@ class CompilationConfig:
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""
 
-    max_capture_size: int = field(default=None, init=False)  # type: ignore
-    """not configurable, computed after init"""
+    max_cudagraph_capture_size: int | None = field(default=None)
+    """The maximum cudagraph capture size.
+
+    If cudagraph_capture_sizes is specified, this will be set to the largest
+    size in that list (or, if this field is also specified, checked for
+    consistency with that list). If cudagraph_capture_sizes is not specified,
+    the list of sizes is generated automatically following the pattern:
+
+        [1, 2, 4] + list(range(8, 256, 8)) + list(
+            range(256, max_cudagraph_capture_size + 1, 16))
+
+    If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2,
+    512) by default. This avoids OOM in tight-memory scenarios with small
+    max_num_seqs, and prevents capturing many large graphs (>512) that would
+    greatly increase startup time with limited performance benefit.
+    """
     local_cache_dir: str = field(default=None, init=False)  # type: ignore
     """local cache dir for each rank"""
     bs_to_padded_graph_size: list[int] = field(
@@ -408,7 +422,7 @@
     )
     """optimization:
     Intuitively, bs_to_padded_graph_size should be dict[int, int];
-    since we know all keys are in a range [0, max_capture_size],
+    since we know all keys are in a range [0, max_cudagraph_capture_size],
     we can optimize it to list[int] for better lookup performance."""
 
     # keep track of enabled and disabled custom ops
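To make the generation pattern in the max_cudagraph_capture_size docstring concrete, here is a minimal standalone sketch. The helper name is hypothetical, and dropping sizes above the cap is an assumption; the docstring only states the progression and the min(max_num_seqs * 2, 512) default:

def default_capture_sizes(max_cudagraph_capture_size: int) -> list[int]:
    # 1, 2, 4, then multiples of 8 up to 256, then multiples of 16 up to the cap.
    sizes = (
        [1, 2, 4]
        + list(range(8, 256, 8))
        + list(range(256, max_cudagraph_capture_size + 1, 16))
    )
    # Assumed: sizes beyond the cap are dropped so small caps still work.
    return [s for s in sizes if s <= max_cudagraph_capture_size]

# Default cap when max_cudagraph_capture_size is unspecified, per the docstring:
max_num_seqs = 64                         # illustrative engine setting
cap = min(max_num_seqs * 2, 512)          # -> 128
print(default_capture_sizes(cap)[-3:])    # [112, 120, 128]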
@@ -672,25 +686,12 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
 
         return VllmBackend(vllm_config)
 
-    def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
-        """To complete the initialization of config,
-        we need to know the cudagraph sizes."""
-
-        if self.cudagraph_capture_sizes is None:
-            self.cudagraph_capture_sizes = cudagraph_capture_sizes
-        else:
-            # de-duplicate the sizes provided by the config
-            dedup_sizes = list(set(self.cudagraph_capture_sizes))
-            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
-                logger.info(
-                    (
-                        "cudagraph sizes specified by model runner"
-                        " %s is overridden by config %s"
-                    ),
-                    cudagraph_capture_sizes,
-                    dedup_sizes,
-                )
-            self.cudagraph_capture_sizes = dedup_sizes
+    def post_init_cudagraph_sizes(self) -> None:
+        """Complete the initialization after the cudagraph-related
+        configs are set. This includes:
+        - initializing compile_sizes
+        - pre-computing the mapping bs_to_padded_graph_size
+        """
 
         computed_compile_sizes = []
         if self.compile_sizes is not None:
@@ -708,23 +709,24 @@ def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
                 computed_compile_sizes.append(x)
         self.compile_sizes = computed_compile_sizes  # type: ignore
 
-        # sort to make sure cudagraph capture sizes are in descending order
-        self.cudagraph_capture_sizes.sort(reverse=True)
-        self.max_capture_size = (
-            self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
-        )
+        # make sure the sizes are in ascending order
+        self.cudagraph_capture_sizes.sort()
+        if self.cudagraph_capture_sizes:
+            assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
 
         # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)]
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
         for end, start in zip(
-            self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
         ):
             for bs in range(start, end):
                 if bs == start:
                     self.bs_to_padded_graph_size[bs] = start
                 else:
                     self.bs_to_padded_graph_size[bs] = end
-        self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size
 
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
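To see what the rewritten padding loop in the hunk above computes, here is a self-contained sketch mirroring the new ascending-order logic (free-function form for illustration; the method works the same way over self.cudagraph_capture_sizes):

def build_padding_table(capture_sizes: list[int], max_size: int) -> list[int]:
    """Round each batch size in [0, max_size] up to the next captured size
    (sizes that are exactly captured map to themselves)."""
    capture_sizes = sorted(capture_sizes)
    assert capture_sizes[-1] == max_size
    table = [0] * (max_size + 1)
    # Pair each captured size (end) with its predecessor (start); the
    # max_size + 1 sentinel closes the last interval, replacing the old
    # explicit table[max_capture_size] assignment that this diff removes.
    for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
        for bs in range(start, end):
            table[bs] = start if bs == start else end
    return table

table = build_padding_table([1, 2, 4, 8], max_size=8)
print(table)     # [0, 1, 2, 4, 4, 8, 8, 8, 8]
print(table[3])  # 4: a batch of 3 runs in the graph captured for size 4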