|
5 | 5 | from torch import fx as fx |
6 | 6 |
|
7 | 7 | from vllm import envs |
8 | | -from vllm.config import VllmConfig |
| 8 | +from vllm.config import VllmConfig, set_current_vllm_config |
9 | 9 | from vllm.logger import init_logger |
10 | 10 | from vllm.platforms import current_platform |
11 | 11 | from vllm.utils import set_env_var |
@@ -86,27 +86,30 @@ def __call__(self, graph: fx.Graph): |
86 | 86 |
|
87 | 87 | def configure(self, config: VllmConfig): |
88 | 88 | self.pass_config = config.compilation_config.pass_config |
89 | | - if self.pass_config.enable_noop: |
90 | | - self.passes += [NoOpEliminationPass(config)] |
91 | 89 |
|
92 | | - if self.pass_config.enable_sequence_parallelism: |
93 | | - self.passes += [SequenceParallelismPass(config)] |
94 | | - if self.pass_config.enable_async_tp: |
95 | | - self.passes += [AsyncTPPass(config)] |
| 90 | + # Set the current vllm config to allow tracing CustomOp instances |
| 91 | + with set_current_vllm_config(config, check_compile=False): |
| 92 | + if self.pass_config.enable_noop: |
| 93 | + self.passes += [NoOpEliminationPass(config)] |
96 | 94 |
|
97 | | - if self.pass_config.enable_fi_allreduce_fusion: |
98 | | - self.passes += [AllReduceFusionPass(config)] |
| 95 | + if self.pass_config.enable_sequence_parallelism: |
| 96 | + self.passes += [SequenceParallelismPass(config)] |
| 97 | + if self.pass_config.enable_async_tp: |
| 98 | + self.passes += [AsyncTPPass(config)] |
99 | 99 |
|
100 | | - if self.pass_config.enable_fusion: |
101 | | - self.passes += [RMSNormQuantFusionPass(config)] |
102 | | - self.passes += [ActivationQuantFusionPass(config)] |
| 100 | + if self.pass_config.enable_fi_allreduce_fusion: |
| 101 | + self.passes += [AllReduceFusionPass(config)] |
103 | 102 |
|
104 | | - if self.pass_config.enable_attn_fusion: |
105 | | - self.passes += [AttnFusionPass(config)] |
| 103 | + if self.pass_config.enable_fusion: |
| 104 | + self.passes += [RMSNormQuantFusionPass(config)] |
| 105 | + self.passes += [ActivationQuantFusionPass(config)] |
106 | 106 |
|
107 | | - # needs a functional graph |
108 | | - self.post_cleanup = PostCleanupPass(config) |
109 | | - self.fix_functionalization = FixFunctionalizationPass(config) |
| 107 | + if self.pass_config.enable_attn_fusion: |
| 108 | + self.passes += [AttnFusionPass(config)] |
| 109 | + |
| 110 | + # needs a functional graph |
| 111 | + self.post_cleanup = PostCleanupPass(config) |
| 112 | + self.fix_functionalization = FixFunctionalizationPass(config) |
110 | 113 |
|
111 | 114 | def add(self, pass_: InductorPass): |
112 | 115 | assert isinstance(pass_, InductorPass) |
|
0 commit comments