 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import cast
+import itertools

 import pytest
 import torch
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp,
-    cutlass_fp8_supported,
+    maybe_create_device_identity,
 )
 from vllm.platforms import current_platform

@@ -54,14 +60,23 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
             act_quant_static=True,
             act_quant_group_shape=GroupShape.PER_TENSOR,
         )
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
+        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

     def forward(self, x):
         y = self.silu_and_mul(x)
         x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
         return x2

     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            (
+                QUANT_OPS[kFp8StaticTensorSym]
+                if self.enable_quant_fp8_custom_op
+                else torch.ops.aten.reciprocal
+            ),
+        ]

     def ops_in_model_after(self):
         return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -77,6 +92,7 @@ def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
         assert silu_and_mul_nvfp4_quant_supported

         self.silu_and_mul = SiluAndMul()
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

         # create nvfp4 weight
         w = torch.rand((hidden_size, hidden_size))
@@ -101,7 +117,10 @@ def forward(self, x):
         return out

     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            QUANT_OPS[kNvfp4Quant],
+        ]

     def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]
@@ -110,67 +129,80 @@ def ops_in_model_after(self):
 @pytest.mark.parametrize("num_tokens", [32, 64])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
 @pytest.mark.parametrize(
-    "model_class",
-    cast(
-        list[type],
-        [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
-        if is_nvfp4_supported()
-        else [TestSiluMulFp8QuantModel],
-    ),
+    "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
+    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+    + [(TestSiluMulNvfp4QuantModel, False, False)],
 )
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
-@pytest.mark.parametrize(
-    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
-)
 @pytest.mark.skipif(
     envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
 )
 def test_fusion_silu_and_mul_quant(
-    num_tokens, hidden_size, dtype, model_class, cuda_force_torch
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
+    enable_silu_mul_custom_op: bool,
+    enable_quant_fp8_custom_op: bool,
+    cuda_force_torch: bool,
 ):
-    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
-        pytest.skip("Duplicate tests for NVFP4")
+    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
+        pytest.skip("NVFP4 is not supported on this GPU.")

     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
+    maybe_create_device_identity()

     x = torch.rand(num_tokens, hidden_size * 2)

     # Reshape pass is needed for the fusion pass to work
-    config = VllmConfig()
-    config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=True, enable_noop=True)
+    custom_ops = []
+    if enable_silu_mul_custom_op:
+        custom_ops.append("+silu_and_mul")
+    if enable_quant_fp8_custom_op:
+        custom_ops.append("+quant_fp8")
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            custom_ops=custom_ops,
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        ),
     )
-    fusion_pass = ActivationQuantFusionPass(config)

-    passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
-    backend = TestBackend(*passes)
-    model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
+    with set_current_vllm_config(config):
+        fusion_pass = ActivationQuantFusionPass(config)

-    # First dimension dynamic
-    torch._dynamo.mark_dynamic(x, 0)
+        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
+        backend = TestBackend(*passes)
+        model = model_class(
+            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
+        )

-    result = model(x)
+        # First dimension dynamic
+        torch._dynamo.mark_dynamic(x, 0)

-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+        result = model(x)

-    # Check that it gives the same answer
-    if model_class == TestSiluMulFp8QuantModel:
-        atol, rtol = 1e-3, 1e-3
-    elif model_class == TestSiluMulNvfp4QuantModel:
-        atol, rtol = 1e-1, 1e-1
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)

-    torch.testing.assert_close(
-        result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
-    )
+        # Check that it gives the same answer
+        if model_class == TestSiluMulFp8QuantModel:
+            atol, rtol = 1e-3, 1e-3
+        elif model_class == TestSiluMulNvfp4QuantModel:
+            atol, rtol = 1e-1, 1e-1
+
+        torch.testing.assert_close(
+            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
+        )

-    assert fusion_pass.matched_count == 1
+        assert fusion_pass.matched_count == 1

-    # In pre-nodes, quant op should be present and fused kernels should not
-    backend.check_before_ops(model.ops_in_model_before())
+        # In pre-nodes, quant op should be present and fused kernels should not
+        backend.check_before_ops(model.ops_in_model_before())

-    # In post-nodes, fused kernels should be present and quant op should not
-    backend.check_after_ops(model.ops_in_model_after())
+        # In post-nodes, fused kernels should be present and quant op should not
+        backend.check_after_ops(model.ops_in_model_after())
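
Note on the new parametrization: the combined "model_class, enable_quant_fp8_custom_op, cuda_force_torch" parametrize replaces the old separate cuda_force_torch parametrize, exercising the FP8 model over every flag combination while the NVFP4 model gets a single case. A minimal sketch of how that expression expands (the class names are shown as plain strings here so the snippet runs on its own; it is only an illustration, not part of the test):

    import itertools

    # Expansion of the diff's itertools.product(...) + [NVFP4 tuple] expression.
    fp8_cases = list(
        itertools.product(["TestSiluMulFp8QuantModel"], [True, False], [True, False])
    )
    params = fp8_cases + [("TestSiluMulNvfp4QuantModel", False, False)]
    print(len(params))  # 5 parameter sets; pytest further multiplies these by
                        # num_tokens, hidden_size, dtype and enable_silu_mul_custom_op.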