Skip to content

Commit 1b1a63e

Browse files
committed
Fix e2e allreduce fusion test
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 52f78ce commit 1b1a63e

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

tests/compile/test_fusions_e2e.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class ModelBackendTestCase(NamedTuple):
6969
model_kwargs=dict(max_model_len=1024),
7070
backend=_Backend.TRITON_ATTN,
7171
attention_fusions=0,
72-
allreduce_fusions=64,
72+
allreduce_fusions=65,
7373
),
7474
]
7575

@@ -166,8 +166,7 @@ def test_attn_quant(
166166

167167

168168
# TODO(luka) test both in nightly
169-
# TODO(luka) change to -
170-
CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"]
169+
CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"]
171170

172171

173172
def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
@@ -180,8 +179,11 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
180179
"model_name, model_kwargs, backend, "
181180
"attention_fusions, allreduce_fusions, custom_ops",
182181
# Toggle RMSNorm and QuantFP8 for FP8 models
183-
list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"]))
184-
# custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO
182+
list(
183+
flat_product(
184+
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
185+
)
186+
) # TODO
185187
# Toggle RMSNorm for FP4 models and unquant models
186188
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
187189
)
@@ -245,17 +247,26 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
245247
run_model(
246248
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
247249
)
248-
249-
assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, (
250-
log_holder.text
250+
matches = re.findall(
251+
r"\[compilation/fusion_attn.py:\d+] "
252+
r"Fused quant onto (\d+) attention nodes",
253+
log_holder.text,
251254
)
255+
assert len(matches) == 2, log_holder.text
256+
257+
assert int(matches[0]) == attention_fusions
258+
assert int(matches[1]) == attention_fusions
252259

253260
matches = re.findall(
254-
rf"\[collective_fusion.py:\d+] Replaced {allreduce_fusions} patterns",
261+
r"\[compilation/collective_fusion.py:\d+] "
262+
r"Replaced (\d+) patterns",
255263
log_holder.text,
256264
)
257265
assert len(matches) == 2, log_holder.text
258266

267+
assert int(matches[0]) == allreduce_fusions
268+
assert int(matches[1]) == allreduce_fusions
269+
259270

260271
def run_model(
261272
compile_config: Union[int, CompilationConfig], model: str, **model_kwargs

0 commit comments

Comments
 (0)