@@ -69,7 +69,7 @@ class ModelBackendTestCase(NamedTuple):
6969 model_kwargs = dict (max_model_len = 1024 ),
7070 backend = _Backend .TRITON_ATTN ,
7171 attention_fusions = 0 ,
72- allreduce_fusions = 64 ,
72+ allreduce_fusions = 65 ,
7373 ),
7474 ]
7575
@@ -166,8 +166,7 @@ def test_attn_quant(
166166
167167
168168# TODO(luka) test both in nightly
169- # TODO(luka) change to -
170- CUSTOM_OPS_RMS_NORM = ["+rms_norm" ] # , "+rms_norm"]
169+ CUSTOM_OPS_RMS_NORM = ["-rms_norm" ] # , "+rms_norm"]
171170
172171
173172def custom_ops_product (* custom_ops_lists : list [str ]) -> Iterable [str ]:
@@ -180,8 +179,11 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
180179 "model_name, model_kwargs, backend, "
181180 "attention_fusions, allreduce_fusions, custom_ops" ,
182181 # Toggle RMSNorm and QuantFP8 for FP8 models
183- list (flat_product (MODELS_FP8 , ["+quant_fp8,+rms_norm" ]))
184- # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO
182+ list (
183+ flat_product (
184+ MODELS_FP8 , custom_ops_product (CUSTOM_OPS_FP8 , CUSTOM_OPS_RMS_NORM )
185+ )
186+ )
185187 # Toggle RMSNorm for FP4 models and unquant models
186188 + list (flat_product (MODELS_FP4 + MODELS , CUSTOM_OPS_RMS_NORM )),
187189)
@@ -245,17 +247,26 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
245247 run_model (
246248 compilation_config , model_name , tensor_parallel_size = 2 , ** model_kwargs
247249 )
248-
249- assert f"Fused quant onto { attention_fusions } attention nodes" in log_holder .text , (
250- log_holder .text
250+ matches = re .findall (
251+ r"\[compilation/fusion_attn.py:\d+] "
252+ r"Fused quant onto (\d+) attention nodes" ,
253+ log_holder .text ,
251254 )
255+ assert len (matches ) == 2 , log_holder .text
256+
257+ assert int (matches [0 ]) == attention_fusions
258+ assert int (matches [1 ]) == attention_fusions
252259
253260 matches = re .findall (
254- rf"\[collective_fusion.py:\d+] Replaced { allreduce_fusions } patterns" ,
261+ r"\[compilation/collective_fusion.py:\d+] "
262+ r"Replaced (\d+) patterns" ,
255263 log_holder .text ,
256264 )
257265 assert len (matches ) == 2 , log_holder .text
258266
267+ assert int (matches [0 ]) == allreduce_fusions
268+ assert int (matches [1 ]) == allreduce_fusions
269+
259270
260271def run_model (
261272 compile_config : Union [int , CompilationConfig ], model : str , ** model_kwargs
0 commit comments