
Commit aaa47a6

ZJY0516 authored and 0xrushi committed
[torch.compile] Enable silu_mul_fp8_quant fusion without custom ops enabled (vllm-project#27146)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent 61e6f00 commit aaa47a6
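
This change lets the SiluMul + FP8 quant fusion pattern match even when the +silu_and_mul and +quant_fp8 custom ops are disabled: the pattern is now traced through MatcherSiluAndMul / MatcherQuantFP8, which emit either the custom op or its native torch decomposition. Below is a minimal sketch of the configuration the updated test exercises; it assumes a CUDA build of vLLM, and the ActivationQuantFusionPass import path is inferred from the file shown further down.

# Sketch mirroring tests/compile/test_silu_mul_quant_fusion.py: build a
# compilation config with both custom ops left disabled and construct the
# fusion pass under that config.
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)

config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=[],  # neither "+silu_and_mul" nor "+quant_fp8"
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    ),
)

with set_current_vllm_config(config):
    fusion_pass = ActivationQuantFusionPass(config)

With the custom ops enabled instead, custom_ops would be ["+silu_and_mul", "+quant_fp8"] and the pre-fusion graph contains the torch.ops._C custom ops rather than their aten decompositions; the test now parameterizes over both combinations.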

File tree

4 files changed: +128 -70 lines changed

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 75 additions & 43 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import cast
+import itertools
 
 import pytest
 import torch
@@ -16,7 +16,13 @@
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
@@ -25,7 +31,7 @@
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp,
-    cutlass_fp8_supported,
+    maybe_create_device_identity,
 )
 from vllm.platforms import current_platform
 
@@ -54,14 +60,23 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
             act_quant_static=True,
             act_quant_group_shape=GroupShape.PER_TENSOR,
         )
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
+        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
 
     def forward(self, x):
         y = self.silu_and_mul(x)
         x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
         return x2
 
     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            (
+                QUANT_OPS[kFp8StaticTensorSym]
+                if self.enable_quant_fp8_custom_op
+                else torch.ops.aten.reciprocal
+            ),
+        ]
 
     def ops_in_model_after(self):
         return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -77,6 +92,7 @@ def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
         assert silu_and_mul_nvfp4_quant_supported
 
         self.silu_and_mul = SiluAndMul()
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
 
         # create nvfp4 weight
         w = torch.rand((hidden_size, hidden_size))
@@ -101,7 +117,10 @@ def forward(self, x):
         return out
 
     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            QUANT_OPS[kNvfp4Quant],
+        ]
 
     def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]
@@ -110,67 +129,80 @@ def ops_in_model_after(self):
 @pytest.mark.parametrize("num_tokens", [32, 64])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
 @pytest.mark.parametrize(
-    "model_class",
-    cast(
-        list[type],
-        [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
-        if is_nvfp4_supported()
-        else [TestSiluMulFp8QuantModel],
-    ),
+    "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
+    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+    + [(TestSiluMulNvfp4QuantModel, False, False)],
 )
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
-@pytest.mark.parametrize(
-    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
-)
 @pytest.mark.skipif(
     envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
 )
 def test_fusion_silu_and_mul_quant(
-    num_tokens, hidden_size, dtype, model_class, cuda_force_torch
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
+    enable_silu_mul_custom_op: bool,
+    enable_quant_fp8_custom_op: bool,
+    cuda_force_torch: bool,
 ):
-    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
-        pytest.skip("Duplicate tests for NVFP4")
+    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
+        pytest.skip("NVFP4 is not supported on this GPU.")
 
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
+    maybe_create_device_identity()
 
     x = torch.rand(num_tokens, hidden_size * 2)
 
     # Reshape pass is needed for the fusion pass to work
-    config = VllmConfig()
-    config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=True, enable_noop=True)
+    custom_ops = []
+    if enable_silu_mul_custom_op:
+        custom_ops.append("+silu_and_mul")
+    if enable_quant_fp8_custom_op:
+        custom_ops.append("+quant_fp8")
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            custom_ops=custom_ops,
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        ),
    )
-    fusion_pass = ActivationQuantFusionPass(config)
 
-    passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
-    backend = TestBackend(*passes)
-    model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
+    with set_current_vllm_config(config):
+        fusion_pass = ActivationQuantFusionPass(config)
 
-    # First dimension dynamic
-    torch._dynamo.mark_dynamic(x, 0)
+        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
+        backend = TestBackend(*passes)
+        model = model_class(
+            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
+        )
 
-    result = model(x)
+        # First dimension dynamic
+        torch._dynamo.mark_dynamic(x, 0)
 
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+        result = model(x)
 
-    # Check that it gives the same answer
-    if model_class == TestSiluMulFp8QuantModel:
-        atol, rtol = 1e-3, 1e-3
-    elif model_class == TestSiluMulNvfp4QuantModel:
-        atol, rtol = 1e-1, 1e-1
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
 
-    torch.testing.assert_close(
-        result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
-    )
+        # Check that it gives the same answer
+        if model_class == TestSiluMulFp8QuantModel:
+            atol, rtol = 1e-3, 1e-3
+        elif model_class == TestSiluMulNvfp4QuantModel:
+            atol, rtol = 1e-1, 1e-1
+
+        torch.testing.assert_close(
+            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
+        )
 
-    assert fusion_pass.matched_count == 1
+        assert fusion_pass.matched_count == 1
 
-    # In pre-nodes, quant op should be present and fused kernels should not
-    backend.check_before_ops(model.ops_in_model_before())
+        # In pre-nodes, quant op should be present and fused kernels should not
+        backend.check_before_ops(model.ops_in_model_before())
 
-    # In post-nodes, fused kernels should be present and quant op should not
-    backend.check_after_ops(model.ops_in_model_after())
+        # In post-nodes, fused kernels should be present and quant op should not
+        backend.check_after_ops(model.ops_in_model_after())

vllm/compilation/activation_quant_fusion.py

Lines changed: 21 additions & 26 deletions
@@ -18,12 +18,12 @@
     QuantKey,
     kFp8StaticTensorSym,
     kNvfp4Quant,
-    kStaticTensorScale,
 )
 from vllm.platforms import current_platform
 
 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
 from .inductor_pass import enable_fake_mode
+from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
@@ -66,6 +66,8 @@ def __init__(
         )
         self.FUSED_OP = FUSED_OPS[self.quant_key]
 
+        self.silu_and_mul_matcher = MatcherSiluAndMul()
+
     def empty_quant(self, *args, **kwargs):
         kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
         return torch.empty(*args, **kwargs)
@@ -80,42 +82,38 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
     Fusion for SiluMul+Fp8StaticQuant Pattern
     """
 
-    def __init__(self, symmetric: bool = True):
-        quant_key = QuantKey(
-            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
-        )
-        super().__init__(quant_key)
+    def __init__(self):
+        super().__init__(kFp8StaticTensorSym)
+        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
 
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(
-            result: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
-            at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
-            at2 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at1[1], scale=scale
-            )
-            return at2[1]
+            result_silu_mul = self.silu_and_mul_matcher(input)
+            result_quant = self.quant_matcher(result_silu_mul, scale)
+            return result_quant[0]
 
         def replacement(
-            result: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
+            d = input.shape[-1] // 2
+            output_shape = input.shape[:-1] + (d,)
+            result = torch.empty(
+                output_shape, device=input.device, dtype=self.quant_dtype
+            )
             at = auto_functionalized(
                 self.FUSED_OP, result=result, input=input, scale=scale
             )
             return at[1]
 
         inputs = [
-            self.empty_quant(5, 4),  # result
-            empty_bf16(5, 4),  # result_silu_mul
-            empty_bf16(5, 4),  # input
-            empty_fp32(1, 1),  # scale
+            *self.silu_and_mul_matcher.inputs(),  # input
+            self.quant_matcher.inputs()[1],  # scale
         ]
+        pattern(*inputs)
 
         register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
 
@@ -132,24 +130,22 @@ def register(self, pm_pass: PatternMatcherPass):
         def pattern(
             result: torch.Tensor,
             output_scale: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
-            at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
-            at2 = auto_functionalized(
+            result_silu_mul = self.silu_and_mul_matcher(input)
+            at = auto_functionalized(
                 self.QUANT_OP,
                 output=result,
-                input=at1[1],
+                input=result_silu_mul,
                 output_scale=output_scale,
                 input_scale=scale,
             )
-            return at2[1], at2[2]
+            return at[1], at[2]
 
         def replacement(
             result: torch.Tensor,
             output_scale: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
@@ -165,7 +161,6 @@ def replacement(
         inputs = [
             self.empty_quant(5, 32),  # result
             empty_i32(128, 4),  # output_scale
-            empty_bf16(5, 64),  # result_silu_mul
             empty_bf16(5, 64),  # input
             empty_fp32(1, 1),  # scale
         ]

vllm/compilation/matcher_utils.py

Lines changed: 30 additions & 0 deletions
@@ -7,6 +7,7 @@
 from torch._ops import OpOverload
 
 from vllm.config import get_current_vllm_config
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -31,6 +32,8 @@
 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
 
+SILU_MUL_OP = torch.ops._C.silu_and_mul.default
+
 
 class MatcherCustomOp(ABC):
     def __init__(self, enabled: bool):
@@ -206,3 +209,30 @@ def inputs(self) -> list[torch.Tensor]:
             return [input, self.empty_f32(1, 1)]
 
         return [input]
+
+
+class MatcherSiluAndMul(MatcherCustomOp):
+    def __init__(self, enabled: bool | None = None):
+        if enabled is None:
+            enabled = SiluAndMul.enabled()
+        super().__init__(enabled)
+
+    def inputs(self) -> list[torch.Tensor]:
+        input = self.empty(5, 4)
+        return [input]
+
+    def forward_custom(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
+        return result[1]
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        return SiluAndMul.forward_native(x)
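
For reference, here is a self-contained sketch of the two graph forms MatcherSiluAndMul covers, condensed from forward_custom and forward_native above. The helper names are illustrative, not vLLM API, and the custom path assumes a CUDA build of vLLM that registers the torch.ops._C kernels.

import torch
import torch.nn.functional as F


def silu_mul_custom(x: torch.Tensor) -> torch.Tensor:
    # Custom-op form: appears as torch.ops._C.silu_and_mul in the FX graph.
    d = x.shape[-1] // 2
    out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
    torch.ops._C.silu_and_mul(out, x)
    return out


def silu_mul_native(x: torch.Tensor) -> torch.Tensor:
    # Native form: decomposes into aten ops, which is why the updated test
    # looks for torch.ops.aten.mul in the pre-fusion graph.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

Tracing the fusion pattern through the matcher means the registered pattern follows whichever form matches the current custom-op configuration.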

vllm/model_executor/layers/activation.py

Lines changed: 2 additions & 1 deletion
@@ -80,7 +80,8 @@ def __init__(self):
         elif current_platform.is_cpu():
             self._forward_method = self.forward_native
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def forward_native(x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
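
Making forward_native a @staticmethod is what lets MatcherSiluAndMul.forward_native above call it without constructing the layer. A tiny usage sketch (pure torch, runs on CPU):

import torch

from vllm.model_executor.layers.activation import SiluAndMul

x = torch.randn(4, 16)            # last dimension is 2 * d
y = SiluAndMul.forward_native(x)  # shape (4, 8); no SiluAndMul() instance needed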
