4 files changed: +47 −2 lines changed
Columns: original file line number | diff line number | diff line change
@@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
341341 async_tp_pass = AsyncTPPass (vllm_config )
342342 backend = TestBackend (async_tp_pass )
343343
344+ assert (
345+ async_tp_pass .compilation_config .splitting_ops
346+ == vllm_config .compilation_config .splitting_ops
347+ )
348+ assert (
349+ async_tp_pass .compilation_config .use_inductor_graph_partition
350+ == vllm_config .compilation_config .use_inductor_graph_partition
351+ )
352+
344353 model = test_model_cls (hidden_size , dtype ) # Pass dtype to model constructor
345354
346355 hidden_states = torch .randn (
Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+ import copy
4+
35import pytest
46
57from vllm .compilation .counter import compilation_counter
8+ from vllm .compilation .fix_functionalization import FixFunctionalizationPass
69from vllm .config import CompilationConfig , CUDAGraphMode , VllmConfig
710from vllm .config .compilation import CompilationMode
811from vllm .utils import _is_torch_equal_or_newer , is_torch_equal_or_newer
@@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
2528 assert vllm_config .compilation_config .use_cudagraph
2629
2730
def test_copy_pass():
    """Check that a deep-copied inductor pass keeps its compilation settings.

    Builds a ``FixFunctionalizationPass`` from a default ``VllmConfig``,
    deep-copies it, and verifies that the copy's ``compilation_config``
    still reports the same ``use_inductor_graph_partition`` and
    ``splitting_ops`` values as the originating config.
    """
    vllm_config = VllmConfig()
    inductor_pass = FixFunctionalizationPass(vllm_config)
    # The pass must survive copy.deepcopy without losing its settings.
    copied_inductor_pass = copy.deepcopy(inductor_pass)
    assert (
        copied_inductor_pass.compilation_config.use_inductor_graph_partition
        == vllm_config.compilation_config.use_inductor_graph_partition
    )
    assert (
        copied_inductor_pass.compilation_config.splitting_ops
        == vllm_config.compilation_config.splitting_ops
    )
43+
44+
2845def test_custom_op ():
2946 # proper syntax
3047 _ = CompilationConfig (custom_ops = ["+quant_fp8" , "-silu_and_mul" ])
Original file line number Diff line number Diff line change @@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(
285285
286286 noop_pass = NoOpEliminationPass (vllm_config )
287287 sequence_parallelism_pass = SequenceParallelismPass (vllm_config )
288+ assert (
289+ sequence_parallelism_pass .compilation_config .splitting_ops
290+ == vllm_config .compilation_config .splitting_ops
291+ )
292+ assert (
293+ sequence_parallelism_pass .compilation_config .use_inductor_graph_partition
294+ == vllm_config .compilation_config .use_inductor_graph_partition
295+ )
288296 func_pass = FixFunctionalizationPass (vllm_config )
289297 cleanup_pass = PostCleanupPass (vllm_config )
290298
Original file line number Diff line number Diff line change 33import functools
44import operator
55import time
6- import weakref
6+ from dataclasses import dataclass
77from typing import ClassVar
88
99import regex as re
1919logger = init_logger (__name__ )
2020
2121
22+ @dataclass
23+ class InductorCompilationConfig :
24+ splitting_ops : list [str ] | None = None
25+ use_inductor_graph_partition : bool = False
26+
27+
2228class VllmInductorPass (InductorPass ):
2329 """
2430 An inductor pass with access to vLLM PassConfig.
@@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass):
2935 """Keep track of pass index for debug dump ordering."""
3036
3137 def __init__ (self , config : VllmConfig ):
32- self .compilation_config = weakref .proxy (config .compilation_config )
38+ # Get only the necessary CompilationConfig for the inductor pass, since
39+ # full `CompilationConfig` contains pointer to model which is unsafe.
40+ self .compilation_config = InductorCompilationConfig (
41+ splitting_ops = config .compilation_config .splitting_ops ,
42+ use_inductor_graph_partition = config .compilation_config .use_inductor_graph_partition ,
43+ )
3344 self .pass_config = config .compilation_config .pass_config
3445 self .model_dtype = config .model_config .dtype if config .model_config else None
3546 self .device = config .device_config .device if config .device_config else None
You can’t perform that action at this time.
0 commit comments