Skip to content

Commit 11ae016

Browse files
luccafong and Lucia (Lu) Fang authored
[torch.compile] Passing only necessary compilation config to inductor pass config (#27041)
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
1 parent 41d3071 commit 11ae016

File tree

4 files changed

+47
-2
lines changed

4 files changed

+47
-2
lines changed

tests/compile/test_async_tp.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
341341
async_tp_pass = AsyncTPPass(vllm_config)
342342
backend = TestBackend(async_tp_pass)
343343

344+
assert (
345+
async_tp_pass.compilation_config.splitting_ops
346+
== vllm_config.compilation_config.splitting_ops
347+
)
348+
assert (
349+
async_tp_pass.compilation_config.use_inductor_graph_partition
350+
== vllm_config.compilation_config.use_inductor_graph_partition
351+
)
352+
344353
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
345354

346355
hidden_states = torch.randn(

tests/compile/test_config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import copy
4+
35
import pytest
46

57
from vllm.compilation.counter import compilation_counter
8+
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
69
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
710
from vllm.config.compilation import CompilationMode
811
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
@@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
2528
assert vllm_config.compilation_config.use_cudagraph
2629

2730

31+
def test_copy_pass():
32+
vllm_config = VllmConfig()
33+
inductor_pass = FixFunctionalizationPass(vllm_config)
34+
copied_inductor_pass = copy.deepcopy(inductor_pass)
35+
assert (
36+
copied_inductor_pass.compilation_config.use_inductor_graph_partition
37+
== vllm_config.compilation_config.use_inductor_graph_partition
38+
)
39+
assert (
40+
copied_inductor_pass.compilation_config.splitting_ops
41+
== vllm_config.compilation_config.splitting_ops
42+
)
43+
44+
2845
def test_custom_op():
2946
# proper syntax
3047
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

tests/compile/test_sequence_parallelism.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(
285285

286286
noop_pass = NoOpEliminationPass(vllm_config)
287287
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
288+
assert (
289+
sequence_parallelism_pass.compilation_config.splitting_ops
290+
== vllm_config.compilation_config.splitting_ops
291+
)
292+
assert (
293+
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
294+
== vllm_config.compilation_config.use_inductor_graph_partition
295+
)
288296
func_pass = FixFunctionalizationPass(vllm_config)
289297
cleanup_pass = PostCleanupPass(vllm_config)
290298

vllm/compilation/vllm_inductor_pass.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import functools
44
import operator
55
import time
6-
import weakref
6+
from dataclasses import dataclass
77
from typing import ClassVar
88

99
import regex as re
@@ -19,6 +19,12 @@
1919
logger = init_logger(__name__)
2020

2121

22+
@dataclass
23+
class InductorCompilationConfig:
24+
splitting_ops: list[str] | None = None
25+
use_inductor_graph_partition: bool = False
26+
27+
2228
class VllmInductorPass(InductorPass):
2329
"""
2430
An inductor pass with access to vLLM PassConfig.
@@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass):
2935
"""Keep track of pass index for debug dump ordering."""
3036

3137
def __init__(self, config: VllmConfig):
32-
self.compilation_config = weakref.proxy(config.compilation_config)
38+
# Get only the necessary CompilationConfig for the inductor pass, since
39+
# full `CompilationConfig` contains pointer to model which is unsafe.
40+
self.compilation_config = InductorCompilationConfig(
41+
splitting_ops=config.compilation_config.splitting_ops,
42+
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
43+
)
3344
self.pass_config = config.compilation_config.pass_config
3445
self.model_dtype = config.model_config.dtype if config.model_config else None
3546
self.device = config.device_config.device if config.device_config else None

0 commit comments

Comments
 (0)