@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
     - use_inductor: whether to use inductor compilation.
         - False: inductor compilation is not used. graph runs in eager.
         - True: inductor compilation is used. one graph for symbolic shape
-            is compiled. In addition, compile for cudagraph sizes that are
-            in candidate_compile_sizes, using configurations
-            in inductor_compile_config.
-    - candidate_compile_sizes: sizes to compile for inductor.
+            is compiled. In addition, compile for compile_sizes,
+            using configurations in inductor_compile_config.
+    - compile_sizes: sizes to compile for inductor. In addition
+        to integers, it also supports "cudagraph_capture_sizes" to
+        specify the sizes for cudagraph capture.
     - inductor_compile_config: additional configurations for inductor.
         - None: use default configurations.
     - inductor_passes: additional passes for inductor. It is a dictionary
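The new field can mix integer sizes with the special string documented above. A minimal sketch of the intended usage, assuming the `vllm.config` import path and illustrative size values (neither is part of this diff):

```python
# Illustrative only: the size values and import path are assumptions.
# Integers in compile_sizes are compiled directly by inductor; the string
# "cudagraph_capture_sizes" expands to the cudagraph capture sizes in use.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    use_inductor=True,
    cudagraph_capture_sizes=[1, 2, 4, 8],
    compile_sizes=[1, 8, "cudagraph_capture_sizes"],
)
```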
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
     splitting_ops: List[str] = Field(default=None)  # type: ignore

     use_inductor: bool = True
-    candidate_compile_sizes: Optional[List[int]] = Field(default=None)
+    compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
     inductor_compile_config: Dict = Field(default_factory=dict)
     inductor_passes: Dict[str, str] = Field(default_factory=dict)

@@ -2790,8 +2791,6 @@ def model_post_init(self, __context: Any) -> None:
     pass_config: PassConfig = Field(default_factory=PassConfig)

     # not configurable, computed after init
-    compile_sizes: List[int] = PrivateAttr
-    capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
     local_cache_dir: str = PrivateAttr  # local cache dir for each rank
     # optimization:
@@ -2918,43 +2917,47 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)

-    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
+    def init_with_cudagraph_sizes(self,
+                                  cudagraph_capture_sizes: List[int]) -> None:
         """To complete the initialization of config,
         we need to know the cudagraph sizes."""

         if self.cudagraph_capture_sizes is None:
-            self.capture_sizes = sizes_to_specialize
+            self.cudagraph_capture_sizes = cudagraph_capture_sizes
         else:
-            self.capture_sizes = self.cudagraph_capture_sizes
+            # de-duplicate the sizes provided by the config
+            self.cudagraph_capture_sizes = list(
+                set(self.cudagraph_capture_sizes))
             logger.info(("cudagraph sizes specified by model runner"
                          " %s is overridden by config %s"),
-                        sizes_to_specialize, self.cudagraph_capture_sizes)
-
-        if self.candidate_compile_sizes is None:
-            self.candidate_compile_sizes = []
-        self.compile_sizes = [
-            x for x in self.candidate_compile_sizes if x in self.capture_sizes
-        ]
-        ignored_sizes = [
-            x for x in self.candidate_compile_sizes
-            if x not in self.capture_sizes
-        ]
-        if ignored_sizes:
-            logger.warning(("candidate_compile_sizes %s are ignored "
-                            "because they are not cudagraph capture sizes."),
-                           ignored_sizes)
+                        cudagraph_capture_sizes, self.cudagraph_capture_sizes)
+
+        computed_compile_sizes = []
+        if self.compile_sizes is not None:
+            # de-duplicate the sizes provided by the config
+            self.compile_sizes = list(set(self.compile_sizes))
+            for x in self.compile_sizes:
+                if isinstance(x, str):
+                    assert x == "cudagraph_capture_sizes", \
+                        "Unrecognized size type in compile_sizes, " \
+                        f"expect 'cudagraph_capture_sizes', got {x}"
+                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
+                else:
+                    assert isinstance(x, int)
+                    computed_compile_sizes.append(x)
+        self.compile_sizes = computed_compile_sizes  # type: ignore

         # sort to make sure cudagraph capture sizes are in descending order
-        self.capture_sizes.sort(reverse=True)
-        self.max_capture_size = self.capture_sizes[
-            0] if self.capture_sizes else 0
+        self.cudagraph_capture_sizes.sort(reverse=True)
+        self.max_capture_size = self.cudagraph_capture_sizes[
+            0] if self.cudagraph_capture_sizes else 0

         # pre-compute the mapping from batch size to padded graph size
         self.bs_to_padded_graph_size = [
             0 for i in range(self.max_capture_size + 1)
         ]
-        for end, start in zip(self.capture_sizes,
-                              self.capture_sizes[1:] + [0]):
+        for end, start in zip(self.cudagraph_capture_sizes,
+                              self.cudagraph_capture_sizes[1:] + [0]):
             for bs in range(start, end):
                 if bs == start:
                     self.bs_to_padded_graph_size[bs] = start
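The tail of this hunk reuses the existing padding-table logic under the renamed attribute. A standalone sketch of how that table is derived from descending capture sizes (hypothetical helper; the final assignment for the largest size mirrors the portion of the method not shown in this hunk):

```python
from typing import List

def build_bs_to_padded_graph_size(capture_sizes: List[int]) -> List[int]:
    """For each batch size up to the largest capture size, record the
    smallest capture size that covers it."""
    sizes = sorted(set(capture_sizes), reverse=True)
    max_size = sizes[0] if sizes else 0
    table = [0] * (max_size + 1)
    for end, start in zip(sizes, sizes[1:] + [0]):
        for bs in range(start, end):
            # an exact capture size maps to itself; anything in between
            # pads up to the next larger capture size
            table[bs] = start if bs == start else end
    if sizes:
        table[max_size] = max_size
    return table

# e.g. with capture sizes [1, 2, 4, 8]: batch size 3 pads to 4, 5-7 pad to 8
assert build_bs_to_padded_graph_size([1, 2, 4, 8])[3] == 4
assert build_bs_to_padded_graph_size([1, 2, 4, 8])[6] == 8
```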
@@ -3225,14 +3228,14 @@ def _set_cudagraph_sizes(self):
         However, if users specify the cudagraph capture sizes through
         compilation config, we will use the specified sizes instead.

-        In the end, `vllm_config.compilation_config.capture_sizes` will be the
-        final sizes to capture cudagraph (in descending order).
+        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
+        will be the final sizes to capture cudagraph (in descending order).

         During runtime, if batchsize is larger than
-        `vllm_config.compilation_config.capture_sizes`,
+        `vllm_config.compilation_config.cudagraph_capture_sizes`,
         no cudagraph will be used.
         If the batch size is no larger than
-        `vllm_config.compilation_config.capture_sizes`,
+        `vllm_config.compilation_config.cudagraph_capture_sizes`,
         we can quickly find the padded graph size for a given batch size by
         looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
         """
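A short sketch of the runtime lookup this docstring describes, using the attribute names from this diff (the surrounding runner code is assumed):

```python
from typing import Optional

def padded_graph_size(compilation_config, batch_size: int) -> Optional[int]:
    """Return the padded cudagraph size for a batch, or None when the
    batch exceeds every capture size and cudagraph is skipped."""
    if batch_size > compilation_config.max_capture_size:
        return None
    return compilation_config.bs_to_padded_graph_size[batch_size]
```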