Commit 1ae80c6

Move global vllm_config to pass manager

Signed-off-by: Luka Govedič <lgovedic@redhat.com>

Parent: b172747

File tree

2 files changed: +34 −32 lines


vllm/compilation/fusion.py

Lines changed: 14 additions & 15 deletions
@@ -9,7 +9,7 @@
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch._ops import OpOverload
 
-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
@@ -334,23 +334,22 @@ def __init__(self, config: VllmConfig):
             pass_name="rmsnorm_quant_fusion_pass"
         )
 
-        with set_current_vllm_config(config, check_compile=False):
-            for epsilon in [1e-5, 1e-6]:
-                # Fuse rms_norm + static fp8 quant
-                RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
+        for epsilon in [1e-5, 1e-6]:
+            # Fuse rms_norm + static fp8 quant
+            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
 
-                # Fuse fused_add_rms_norm + static fp8 quant
-                FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
-                    self.patterns
-                )
+            # Fuse fused_add_rms_norm + static fp8 quant
+            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
+                self.patterns
+            )
 
-                # Fuse rms_norm + dynamic per-token fp8 quant
-                RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
+            # Fuse rms_norm + dynamic per-token fp8 quant
+            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
 
-                # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
-                FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
-                    self.patterns
-                )
+            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
+            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
+                self.patterns
+            )
 
         self.dump_patterns(config, self.patterns)
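The with set_current_vllm_config(...) wrapper can be dropped from RMSNormQuantFusionPass.__init__ because the pass manager (second file below) now opens that context around pass construction. A minimal sketch of why a constructor can rely on a context opened by its caller, using a plain module-level variable as a stand-in for vLLM's actual config plumbing (all names below are hypothetical):

# Hypothetical stand-in for vLLM's config context; illustration only.
from contextlib import contextmanager

_current_config = None

@contextmanager
def set_current_config(config):
    global _current_config
    prev, _current_config = _current_config, config
    try:
        yield
    finally:
        _current_config = prev  # always restore, even if the body raises

class FusionPass:
    def __init__(self):
        # The constructor no longer opens the context itself; it reads
        # whatever config the caller has made current.
        self.config = _current_config

with set_current_config({"enable_fusion": True}):
    p = FusionPass()  # sees the active config without taking it as an argument

assert p.config == {"enable_fusion": True}
assert _current_config is None  # restored once the with block exits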

vllm/compilation/pass_manager.py

Lines changed: 20 additions & 17 deletions
@@ -5,7 +5,7 @@
 from torch import fx as fx
 
 from vllm import envs
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import set_env_var
@@ -86,27 +86,30 @@ def __call__(self, graph: fx.Graph):
 
     def configure(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
-        if self.pass_config.enable_noop:
-            self.passes += [NoOpEliminationPass(config)]
 
-        if self.pass_config.enable_sequence_parallelism:
-            self.passes += [SequenceParallelismPass(config)]
-        if self.pass_config.enable_async_tp:
-            self.passes += [AsyncTPPass(config)]
+        # Set the current vllm config to allow tracing CustomOp instances
+        with set_current_vllm_config(config, check_compile=False):
+            if self.pass_config.enable_noop:
+                self.passes += [NoOpEliminationPass(config)]
 
-        if self.pass_config.enable_fi_allreduce_fusion:
-            self.passes += [AllReduceFusionPass(config)]
+            if self.pass_config.enable_sequence_parallelism:
+                self.passes += [SequenceParallelismPass(config)]
+            if self.pass_config.enable_async_tp:
+                self.passes += [AsyncTPPass(config)]
 
-        if self.pass_config.enable_fusion:
-            self.passes += [RMSNormQuantFusionPass(config)]
-            self.passes += [ActivationQuantFusionPass(config)]
+            if self.pass_config.enable_fi_allreduce_fusion:
+                self.passes += [AllReduceFusionPass(config)]
 
-        if self.pass_config.enable_attn_fusion:
-            self.passes += [AttnFusionPass(config)]
+            if self.pass_config.enable_fusion:
+                self.passes += [RMSNormQuantFusionPass(config)]
+                self.passes += [ActivationQuantFusionPass(config)]
 
-        # needs a functional graph
-        self.post_cleanup = PostCleanupPass(config)
-        self.fix_functionalization = FixFunctionalizationPass(config)
+            if self.pass_config.enable_attn_fusion:
+                self.passes += [AttnFusionPass(config)]
+
+            # needs a functional graph
+            self.post_cleanup = PostCleanupPass(config)
+            self.fix_functionalization = FixFunctionalizationPass(config)
 
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
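With this change the config context is opened once, around the construction of every pass, instead of separately inside individual passes. A hedged usage sketch (it assumes PostGradPassManager is the class that owns configure() and that VllmConfig is constructible with defaults; neither detail appears in the diff):

from vllm.compilation.pass_manager import PostGradPassManager  # assumed class name
from vllm.config import VllmConfig

config = VllmConfig()  # assumed default-constructible, for illustration
pm = PostGradPassManager()
# All enabled passes are now built inside set_current_vllm_config(config),
# so CustomOp instances created during pass construction can be traced.
pm.configure(config)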
