Skip to content

Commit 5493d3c

Browse files
Lucia (Lu) Fangfacebook-github-bot
authored andcommitted
Passing only necessary compilation config to inductor pass config (#27041)
Summary: Pull Request resolved: #27041 we should not pass the weakref to compilation_config, which include static_forward_context that will holds the pointers to the model layers (e.g. moe, attention), which is dangerous, as this will be passed as config to torch.compile Test Plan: local tests Differential Revision: D84790018
1 parent 7bb736d commit 5493d3c

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

vllm/compilation/vllm_inductor_pass.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,36 @@
33
import functools
44
import operator
55
import time
6-
import weakref
76
from typing import ClassVar
87

98
import regex as re
109
import torch
10+
from dataclasses import dataclass
1111
from torch._dynamo.utils import lazy_format_graph_code
1212
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
1313

14-
from vllm.config import VllmConfig
14+
from vllm.config import CompilationConfig, VllmConfig
1515
from vllm.logger import init_logger
1616

1717
from .inductor_pass import InductorPass
1818

1919
logger = init_logger(__name__)
2020

21+
@dataclass
22+
class SimplifiedCompilationConfig:
23+
splitting_ops: list[str] | None = None
24+
use_inductor_graph_partition: bool = False
25+
compile_sizes: list[int | str] | None = None
26+
27+
def copy_necessary_config_for_pass(config: CompilationConfig) -> SimplifiedCompilationConfig:
28+
"""Get only the necessary CompilationConfig for the inductor pass,
29+
since full`CompilationConfig` contains pointer to model which is unsafe.
30+
"""
31+
return SimplifiedCompilationConfig(
32+
splitting_ops=config.splitting_ops,
33+
use_inductor_graph_partition=config.use_inductor_graph_partition,
34+
compile_sizes=config.compile_sizes,
35+
)
2136

2237
class VllmInductorPass(InductorPass):
2338
"""
@@ -29,7 +44,7 @@ class VllmInductorPass(InductorPass):
2944
"""Keep track of pass index for debug dump ordering."""
3045

3146
def __init__(self, config: VllmConfig):
32-
self.compilation_config = weakref.proxy(config.compilation_config)
47+
self.compilation_config = copy_necessary_config_for_pass(config.compilation_config)
3348
self.pass_config = config.compilation_config.pass_config
3449
self.model_dtype = config.model_config.dtype if config.model_config else None
3550
self.device = config.device_config.device if config.device_config else None

0 commit comments

Comments
 (0)