@@ -71,8 +71,8 @@ def __init__(
             act_quant_group_shape=group_shape,
         )

-        self.enable_rms_norm = self.norm[0].enabled()
-        self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled()
+        self.enable_rms_norm_custom_op = self.norm[0].enabled()
+        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

     def forward(self, x):
         # avoid having graph input be an arg to a pattern directly
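Side note on what the renamed flags capture: ops like RMSNorm in vLLM are custom ops whose `enabled()` reports whether the custom kernel (rather than the native torch decomposition) will run, which is exactly what determines which ops the fusion pass should expect in the traced graph. A rough sketch of that dispatch pattern; the class and method names here are illustrative, not vLLM's actual `CustomOp` API:

```python
import torch

class ToggleableOp(torch.nn.Module):
    """Hypothetical stand-in for vLLM's custom-op toggle.

    enabled() answers "will the custom kernel run?", which the test model
    records so it knows which ops to expect in the traced graph.
    """

    def __init__(self, use_custom_kernel: bool):
        super().__init__()
        self.use_custom_kernel = use_custom_kernel

    def enabled(self) -> bool:
        return self.use_custom_kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Enabled -> one opaque fused op appears in the graph; disabled ->
        # the native decomposition (plain aten ops) is what gets traced.
        return self.forward_custom(x) if self.enabled() else self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # plain-torch fallback path

    def forward_custom(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # fused-kernel path
```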
@@ -107,21 +107,25 @@ def ops_in_model_after(self):
     def ops_in_model_before(self):
         return (
             [QUANT_OPS[self.quant_key]]
-            if self.enable_quant_fp8
+            if self.enable_quant_fp8_custom_op
             else [torch.ops.aten.reciprocal]
         )

     def ops_in_model_before_partial(self):
-        return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt]
+        return (
+            [RMS_OP, RMS_ADD_OP]
+            if self.enable_rms_norm_custom_op
+            else [torch.ops.aten.rsqrt]
+        )

 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("enable_rms_norm", [True, False])
-@pytest.mark.parametrize("enable_quant_fp8", [True, False])
+@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
+@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
 # cuda_force_torch is used to test the torch code path on platforms where
 # cutlass_fp8_supported() == True.
 @pytest.mark.parametrize(
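Why `aten.rsqrt` and `aten.reciprocal` serve as the "before" markers when the custom ops are disabled: the native torch implementations decompose into exactly those aten ops once traced. A hand-written approximation of the native paths (not vLLM's actual `forward_native` bodies):

```python
import torch

def rms_norm_native(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # torch.rsqrt traces to torch.ops.aten.rsqrt, the marker
    # ops_in_model_before_partial() looks for when +rms_norm is off.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def quant_fp8_native(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Multiplying by scale.reciprocal() traces to torch.ops.aten.reciprocal,
    # the marker ops_in_model_before() looks for when +quant_fp8 is off.
    finfo = torch.finfo(torch.float8_e4m3fn)
    qx = (x.float() * scale.reciprocal()).clamp(finfo.min, finfo.max)
    return qx.to(torch.float8_e4m3fn)
```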
@@ -136,8 +140,8 @@ def test_fusion_rmsnorm_quant(
     num_tokens,
     eps,
     static,
-    enable_rms_norm,
-    enable_quant_fp8,
+    enable_rms_norm_custom_op,
+    enable_quant_fp8_custom_op,
     cuda_force_torch,
 ):
     torch.set_default_device("cuda")
@@ -146,9 +150,9 @@ def test_fusion_rmsnorm_quant(
     maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

     custom_ops = []
-    if enable_rms_norm:
+    if enable_rms_norm_custom_op:
         custom_ops.append("+rms_norm")
-    if enable_quant_fp8:
+    if enable_quant_fp8_custom_op:
         custom_ops.append("+quant_fp8")
     vllm_config = VllmConfig(
         model_config=ModelConfig(dtype=dtype),
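On the `custom_ops` strings: `"+rms_norm"` and `"+quant_fp8"` use vLLM's toggle syntax, where `+name` force-enables a registered custom op and `-name` force-disables it. A minimal sketch of wiring that list into the config, assuming the current `vllm.config` names (`CompilationConfig.custom_ops` in particular):

```python
import torch
from vllm.config import CompilationConfig, ModelConfig, VllmConfig

# "+name" force-enables a custom op, "-name" force-disables it; omitting a
# name leaves it at its default. The test builds this list from the two flags.
custom_ops = ["+rms_norm", "+quant_fp8"]

vllm_config = VllmConfig(
    model_config=ModelConfig(dtype=torch.bfloat16),
    compilation_config=CompilationConfig(custom_ops=custom_ops),
)
```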
@@ -195,7 +199,7 @@ def test_fusion_rmsnorm_quant(
     # there's a risk that the fused add doesn't get included in the
     # replacement and only the rms part gets fused with quant.
     # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
-    if not enable_rms_norm:
+    if not enable_rms_norm_custom_op:
         n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
         # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
         assert n_add_nodes(backend.graph_pre_pass) == 7
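`find_op_nodes` is a helper from vLLM's test utilities; conceptually it walks the FX graph for `call_function` nodes matching an overload packet. A hypothetical standalone equivalent (not the actual helper):

```python
import torch
from torch import fx

def count_op_nodes(op: torch._ops.OpOverloadPacket, graph: fx.Graph) -> int:
    """Hypothetical standalone equivalent of summing over find_op_nodes()."""
    return sum(
        1
        for node in graph.nodes
        if node.op == "call_function"
        # Match any overload of `op`, e.g. aten.add.Tensor and aten.add.Scalar
        # both belong to the torch.ops.aten.add overload packet.
        and getattr(node.target, "overloadpacket", None) is op
    )
```

Per the in-test comment, the pre-pass count of 7 comes from the plain RMS pattern leaving one `aten.add` (the epsilon add) and each of the three RMS_ADD patterns leaving two (residual add plus epsilon add): 1 + 3×2 = 7.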