# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import cast
+import itertools

import pytest
import torch
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp,
-    cutlass_fp8_supported,
+    maybe_create_device_identity,
)
from vllm.platforms import current_platform

@@ -54,14 +60,23 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
+        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

    def forward(self, x):
        y = self.silu_and_mul(x)
        x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
        return x2

    def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            (
+                QUANT_OPS[kFp8StaticTensorSym]
+                if self.enable_quant_fp8_custom_op
+                else torch.ops.aten.reciprocal
+            ),
+        ]

    def ops_in_model_after(self):
        return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -77,6 +92,7 @@ def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
        assert silu_and_mul_nvfp4_quant_supported

        self.silu_and_mul = SiluAndMul()
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

        # create nvfp4 weight
        w = torch.rand((hidden_size, hidden_size))
@@ -101,7 +117,10 @@ def forward(self, x):
        return out

    def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            QUANT_OPS[kNvfp4Quant],
+        ]

    def ops_in_model_after(self):
        return [FUSED_OPS[kNvfp4Quant]]
@@ -110,67 +129,80 @@ def ops_in_model_after(self):
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
-    "model_class",
-    cast(
-        list[type],
-        [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
-        if is_nvfp4_supported()
-        else [TestSiluMulFp8QuantModel],
-    ),
+    "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
+    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+    + [(TestSiluMulNvfp4QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
-@pytest.mark.parametrize(
-    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
-)
@pytest.mark.skipif(
    envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
-    num_tokens, hidden_size, dtype, model_class, cuda_force_torch
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
+    enable_silu_mul_custom_op: bool,
+    enable_quant_fp8_custom_op: bool,
+    cuda_force_torch: bool,
):
-    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
-        pytest.skip("Duplicate tests for NVFP4")
+    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
+        pytest.skip("NVFP4 is not supported on this GPU.")

    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
+    maybe_create_device_identity()

    x = torch.rand(num_tokens, hidden_size * 2)

    # Reshape pass is needed for the fusion pass to work
-    config = VllmConfig()
-    config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=True, enable_noop=True)
+    custom_ops = []
+    if enable_silu_mul_custom_op:
+        custom_ops.append("+silu_and_mul")
+    if enable_quant_fp8_custom_op:
+        custom_ops.append("+quant_fp8")
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            custom_ops=custom_ops,
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        ),
    )
-    fusion_pass = ActivationQuantFusionPass(config)

-    passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
-    backend = TestBackend(*passes)
-    model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
+    with set_current_vllm_config(config):
+        fusion_pass = ActivationQuantFusionPass(config)

-    # First dimension dynamic
-    torch._dynamo.mark_dynamic(x, 0)
+        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
+        backend = TestBackend(*passes)
+        model = model_class(
+            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
+        )

-    result = model(x)
+        # First dimension dynamic
+        torch._dynamo.mark_dynamic(x, 0)

-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+        result = model(x)

-    # Check that it gives the same answer
-    if model_class == TestSiluMulFp8QuantModel:
-        atol, rtol = 1e-3, 1e-3
-    elif model_class == TestSiluMulNvfp4QuantModel:
-        atol, rtol = 1e-1, 1e-1
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)

-    torch.testing.assert_close(
-        result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
-    )
+        # Check that it gives the same answer
+        if model_class == TestSiluMulFp8QuantModel:
+            atol, rtol = 1e-3, 1e-3
+        elif model_class == TestSiluMulNvfp4QuantModel:
+            atol, rtol = 1e-1, 1e-1
+
+        torch.testing.assert_close(
+            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
+        )

-    assert fusion_pass.matched_count == 1
+        assert fusion_pass.matched_count == 1

-    # In pre-nodes, quant op should be present and fused kernels should not
-    backend.check_before_ops(model.ops_in_model_before())
+        # In pre-nodes, quant op should be present and fused kernels should not
+        backend.check_before_ops(model.ops_in_model_before())

-    # In post-nodes, fused kernels should be present and quant op should not
-    backend.check_after_ops(model.ops_in_model_after())
+        # In post-nodes, fused kernels should be present and quant op should not
+        backend.check_after_ops(model.ops_in_model_after())
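
For reference, a minimal standalone sketch of what the combined model_class / enable_quant_fp8_custom_op / cuda_force_torch parametrization above expands to; the model classes are stood in by plain strings so the snippet runs on its own. It yields four FP8 combinations plus a single NVFP4 case, each further multiplied by the separate num_tokens, hidden_size, dtype, and enable_silu_mul_custom_op parameters.

    import itertools

    # Mirrors the pytest.mark.parametrize argument in the diff above, with the
    # test model classes replaced by placeholder strings.
    fp8_cases = list(
        itertools.product(["TestSiluMulFp8QuantModel"], [True, False], [True, False])
    )
    cases = fp8_cases + [("TestSiluMulNvfp4QuantModel", False, False)]
    assert len(cases) == 5  # 4 FP8 combinations + 1 NVFP4 case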