@@ -66,8 +66,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
     def forward(self, hidden_states, residual):
         view = hidden_states.reshape(-1, self.hidden_size)
         all_reduce = tensor_model_parallel_all_reduce(view)
-        norm, _ = self.norm(all_reduce, residual)
-        return norm
+        norm, res = self.norm(all_reduce, residual)
+
+        return norm, res
 
     def ops_in_model_before(self):
         return [torch.ops.vllm.all_reduce.default]
@@ -98,7 +99,9 @@ def ops_in_model_after(self):
     def ops_in_model_before(self):
         return [
             torch.ops.vllm.all_reduce.default,
-            torch.ops._C.static_scaled_fp8_quant.default,
+            torch.ops._C.static_scaled_fp8_quant.default
+            if self.quant_fp8.enabled()
+            else torch.ops.aten.reciprocal.default,
         ]
 
 
@@ -139,19 +142,21 @@ def ops_in_model_before(self):
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
-    "test_model",
+    "test_model, enable_quant_fp8",
     [
-        TestAllReduceRMSNormModel,
-        TestAllReduceFusedAddRMSNormModel,
-        TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
+        (TestAllReduceRMSNormModel, False),
+        (TestAllReduceFusedAddRMSNormModel, False),
+        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True),
+        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False),
         # TODO: Enable with torch==2.8.0
-        # TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
+        # (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
     ],
 )
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [8])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("enable_rms_norm", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 @pytest.mark.skipif(
     not find_spec("flashinfer")
@@ -165,6 +170,8 @@ def test_all_reduce_fusion_pass_replace(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
+    enable_rms_norm,
+    enable_quant_fp8,
 ):
     num_processes = 2
     if (
@@ -179,7 +186,16 @@ def test_all_reduce_fusion_pass_replace(
     def run_torch_spawn(fn, nprocs):
         torch.multiprocessing.spawn(
             fn,
-            args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
+            args=(
+                num_processes,
+                test_model,
+                batch_size,
+                seq_len,
+                hidden_size,
+                dtype,
+                enable_rms_norm,
+                enable_quant_fp8,
+            ),
             nprocs=nprocs,
         )
 
@@ -194,6 +210,8 @@ def all_reduce_fusion_pass_on_test_model(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
+    enable_rms_norm,
+    enable_quant_fp8,
 ):
     current_platform.seed_everything(0)
 
@@ -215,9 +233,15 @@ def all_reduce_fusion_pass_on_test_model(
     init_distributed_environment()
     initialize_model_parallel(tensor_model_parallel_size=world_size)
 
+    custom_ops = []
+    if enable_rms_norm:
+        custom_ops.append("+rms_norm")
+    if enable_quant_fp8:
+        custom_ops.append("+quant_fp8")
+
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
+            level=CompilationLevel.PIECEWISE, custom_ops=custom_ops
         )
     )
     vllm_config.compilation_config.pass_config = PassConfig(
@@ -239,7 +263,7 @@ def all_reduce_fusion_pass_on_test_model(
     cleanup_pass = PostCleanupPass(vllm_config)
 
     backend = TestBackend(
-        all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass
+        noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
     )
 
     token_num = batch_size * seq_len
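
For context, a minimal standalone sketch of the configuration logic this diff introduces: the custom_ops list passed to CompilationConfig is now derived from the enable_rms_norm / enable_quant_fp8 flags instead of being hard-coded. The helper name make_vllm_config and the "from vllm.config import ..." path are assumptions for illustration; the diff does not show the test module's imports.

# Illustrative sketch only (helper name and import path are assumptions,
# not taken from the diff): build a piecewise-compiled VllmConfig whose
# custom_ops list is toggled by the same flags the test parametrizes.
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig


def make_vllm_config(enable_rms_norm: bool, enable_quant_fp8: bool) -> VllmConfig:
    custom_ops = []
    if enable_rms_norm:
        custom_ops.append("+rms_norm")   # opt in to the custom RMSNorm op
    if enable_quant_fp8:
        custom_ops.append("+quant_fp8")  # opt in to the custom FP8 quant op
    return VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE, custom_ops=custom_ops
        )
    )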