1111from torch ._dynamo .utils import lazy_format_graph_code
1212from torch ._inductor .pattern_matcher import PatternMatcherPass , PatternPrettyPrinter
1313
14- from vllm .config import CompilationConfig , VllmConfig
14+ from vllm .config import VllmConfig
1515from vllm .logger import init_logger
1616
1717from .inductor_pass import InductorPass
2020
2121
2222@dataclass
23- class SimplifiedCompilationConfig :
23+ class InductorCompilationConfig :
2424 splitting_ops : list [str ] | None = None
2525 use_inductor_graph_partition : bool = False
2626 compile_sizes : list [int | str ] | None = None
2727
2828
29- def copy_necessary_config_for_pass (
30- config : CompilationConfig ,
31- ) -> SimplifiedCompilationConfig :
32- """Get only the necessary CompilationConfig for the inductor pass, since
33- full `CompilationConfig` contains pointer to model which is unsafe.
34- """
35- return SimplifiedCompilationConfig (
36- splitting_ops = config .splitting_ops ,
37- use_inductor_graph_partition = config .use_inductor_graph_partition ,
38- compile_sizes = config .compile_sizes ,
39- )
40-
41-
4229class VllmInductorPass (InductorPass ):
4330 """
4431 An inductor pass with access to vLLM PassConfig.
@@ -49,8 +36,12 @@ class VllmInductorPass(InductorPass):
4936 """Keep track of pass index for debug dump ordering."""
5037
5138 def __init__ (self , config : VllmConfig ):
52- self .compilation_config = copy_necessary_config_for_pass (
53- config .compilation_config
39+ # Get only the necessary CompilationConfig for the inductor pass, since
40+ # full `CompilationConfig` contains pointer to model which is unsafe.
41+ self .compilation_config = InductorCompilationConfig (
42+ splitting_ops = config .compilation_config .splitting_ops ,
43+ use_inductor_graph_partition = config .compilation_config .use_inductor_graph_partition ,
44+ compile_sizes = config .compilation_config .compile_sizes ,
5445 )
5546 self .pass_config = config .compilation_config .pass_config
5647 self .model_dtype = config .model_config .dtype if config .model_config else None
0 commit comments