| 
1 | 1 | # SPDX-License-Identifier: Apache-2.0  | 
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project  | 
3 | 3 | 
 
  | 
 | 4 | +from contextlib import ExitStack  | 
 | 5 | + | 
4 | 6 | from torch import fx as fx  | 
5 | 7 | 
 
  | 
 | 8 | +from vllm import envs  | 
6 | 9 | from vllm.config import VllmConfig  | 
7 | 10 | from vllm.logger import init_logger  | 
8 | 11 | from vllm.platforms import current_platform  | 
 | 12 | +from vllm.utils import set_env_var  | 
9 | 13 | 
 
  | 
10 | 14 | if current_platform.is_cuda_alike():  | 
11 | 15 |     from .fusion import FusionPass  | 
@@ -43,13 +47,20 @@ def __init__(self):  | 
43 | 47 |         self.passes: list[VllmInductorPass] = []  | 
44 | 48 | 
 
  | 
45 | 49 |     def __call__(self, graph: fx.Graph):  | 
46 |  | -        shape = get_pass_context().runtime_shape  | 
47 |  | -        for pass_ in self.passes:  | 
48 |  | -            if pass_.is_applicable_for_shape(shape):  | 
49 |  | -                pass_(graph)  | 
50 |  | - | 
51 |  | -        # always run fix_functionalization last  | 
52 |  | -        self.fix_functionalization(graph)  | 
 | 50 | +        with ExitStack() as stack:  | 
 | 51 | +            if envs.VLLM_PATTERN_MATCH_DEBUG is not None:  | 
 | 52 | +                # and get_tensor_model_parallel_rank() == 0:  | 
 | 53 | +                stack.enter_context(  | 
 | 54 | +                    set_env_var('TORCHINDUCTOR_PATTERN_MATCH_DEBUG',  | 
 | 55 | +                                envs.VLLM_PATTERN_MATCH_DEBUG))  | 
 | 56 | + | 
 | 57 | +            shape = get_pass_context().runtime_shape  | 
 | 58 | +            for pass_ in self.passes:  | 
 | 59 | +                if pass_.is_applicable_for_shape(shape):  | 
 | 60 | +                    pass_(graph)  | 
 | 61 | + | 
 | 62 | +            # always run fix_functionalization last  | 
 | 63 | +            self.fix_functionalization(graph)  | 
53 | 64 | 
 
  | 
54 | 65 |     def configure(self, config: VllmConfig):  | 
55 | 66 |         self.pass_config = config.compilation_config.pass_config  | 
 | 
0 commit comments