7 changes: 4 additions & 3 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -20,7 +20,7 @@
 from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
-from onnxscript.rewriter.ort_fusions.mha import fuse_mha
+from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
 from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
 from onnxscript.rewriter.ort_fusions.rotary_embedding import (
     fuse_partial_rotary_embedding,
@@ -87,8 +87,9 @@ def fuse(func, apply_shape_inference: bool = False):
     # in the rewrite rule for certain patterns of SDPA.
     fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
     # Optimize to avoid trying multiple attention-based fusions
-    fusion_count["mha"] = fuse(fuse_mha)
-    if fusion_count["mha"] == 0:
+    fusion_count["mha1"] = fuse(fuse_mha1)
+    fusion_count["mha2"] = fuse(fuse_mha2)
+    if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
         # If no MHA fusion was applied, we can try the GQA fusion.
         # and avoid trying the attention fusion.
         fusion_count["gqa"] = fuse(fuse_gqa)
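Note on the new control flow in fuse_xformers: the two MHA passes replace the single fuse_mha pass, and the GQA fallback is attempted only when neither pass matched (the real code also uses this to decide whether to try the plain Attention fusion). A minimal standalone sketch of that ordering, assuming the fusion helpers return a match count when called on a loaded ONNX IR model, as the updated tests below do; the function name apply_mha_then_gqa is hypothetical:

    from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
    from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2

    def apply_mha_then_gqa(model):
        # Past-aware MHA rules first, then the rules without past state.
        mha_count = fuse_mha1(model) + fuse_mha2(model)
        if mha_count == 0:
            # Fall back to GQA fusion only when no MHA pattern matched (as in _core.py).
            return fuse_gqa(model)
        return mha_count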
3 changes: 2 additions & 1 deletion onnxscript/rewriter/ort_fusions/attention_test.py
@@ -173,7 +173,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         fused_mha_bias_count = xformers.fuse_mha_bias(model)
         self.assertGreater(fused_mha_bias_count, 0)
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
@@ -27,7 +27,7 @@ def test_fuse_xformers(self):
         self.assertEqual(fusion_count["partial_rotary_embedding"], 0)
         self.assertEqual(fusion_count["cos_sin_cache"], 2)
         self.assertEqual(fusion_count["sdpa"], 1)
-        self.assertEqual(fusion_count["mha"], 1)
+        self.assertEqual(fusion_count["mha1"] + fusion_count["mha2"], 1)
         self.assertEqual(fusion_count["attention"], 0)
         self.assertEqual(fusion_count["gqa"], 0)
         self.assertEqual(fusion_count["gelu"], 0)
83 changes: 44 additions & 39 deletions onnxscript/rewriter/ort_fusions/mha.py
@@ -376,45 +376,50 @@ def rewrite(
         )
 
 
-parameter_combinations = [
-    {
-        "double_transpose": double_transpose,
-        "transpose_4d": transpose_4d,
-        "pre_scale_q": pre_scale_q,
-        "is_rotary": is_rotary,
-        "use_mask": use_mask,
-        "has_past_present": has_past_present,
-        "is_cross_attention": is_cross_attention,
-    }
-    for double_transpose in [False, True]
-    for transpose_4d in (
-        [False, True] if double_transpose else [False]
-    )  # Only generate patterns when double_transpose is True
-    for pre_scale_q in [True, False]
-    for is_rotary in [False, True]
-    for use_mask in [False, True]
-    for is_cross_attention in [False, True]
-    for has_past_present in ([False] if is_cross_attention else [True, False])
-    # Skip if both has_past_present and is_cross_attention are True
-    if not (has_past_present and is_cross_attention)
-]
-
-# Dynamically create the rules
-mha_rules = pattern.RewriteRuleSet(
-    [
-        MultiHeadAttention.rule(
-            f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
-            f"{'_Twice' if params['double_transpose'] else ''}"
-            f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
-            f"{'_Rotary' if params['is_rotary'] else ''}"
-            f"{'_Masked' if params['use_mask'] else ''}"
-            f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
-            **params,
-        )
-        for params in parameter_combinations
+def _make_rule_set(has_past_present: bool):
+    parameter_combinations = [
+        {
+            "double_transpose": double_transpose,
+            "transpose_4d": transpose_4d,
+            "pre_scale_q": pre_scale_q,
+            "is_rotary": is_rotary,
+            "use_mask": use_mask,
+            "has_past_present": has_past_present,
+            "is_cross_attention": is_cross_attention,
+        }
+        for double_transpose in [False, True]
+        for transpose_4d in (
+            [False, True] if double_transpose else [False]
+        )  # Only generate patterns when double_transpose is True
+        for pre_scale_q in [True, False]
+        for is_rotary in [False, True]
+        for use_mask in [False, True]
+        for is_cross_attention in ([False] if has_past_present else [False, True])
     ]
-)
 
+    # Dynamically create the rules
+    mha_rules = pattern.RewriteRuleSet(
+        [
+            MultiHeadAttention.rule(
+                f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+                f"{'_Twice' if params['double_transpose'] else ''}"
+                f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
+                f"{'_Rotary' if params['is_rotary'] else ''}"
+                f"{'_Masked' if params['use_mask'] else ''}"
+                f"{'_Past' if params['has_past_present'] else ''}"
+                f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
+                **params,
+            )
+            for params in parameter_combinations
+        ]
+    )
+
+    return mha_rules
+
+
+mha_rules_no_past = _make_rule_set(has_past_present=False)
+mha_rules_with_past = _make_rule_set(has_past_present=True)
 
-fuse_mha = _fusion_utils.apply_fusion_rules(mha_rules)
+# Try rules with past first, and then rules without past.
+fuse_mha1 = _fusion_utils.apply_fusion_rules(mha_rules_with_past)
+fuse_mha2 = _fusion_utils.apply_fusion_rules(mha_rules_no_past)
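A quick sanity check on the sizes of the two rule sets generated above (a back-of-the-envelope note, not part of the change): double_transpose and transpose_4d yield 3 valid pairs, since transpose_4d only varies when double_transpose is True; pre_scale_q, is_rotary and use_mask contribute a factor of 2 each; and is_cross_attention is pinned to False when has_past_present is True. That gives 3 × 8 × 1 = 24 rules with past and 3 × 8 × 2 = 48 rules without past, the same 72 combinations the old single comprehension produced. A tiny sketch of that arithmetic (the helper name expected_rule_count is hypothetical):

    def expected_rule_count(has_past_present: bool) -> int:
        # (double_transpose, transpose_4d): (False, False), (True, False), (True, True)
        dt_t4d_pairs = 3
        bool_axes = 2 * 2 * 2  # pre_scale_q, is_rotary, use_mask
        cross_options = 1 if has_past_present else 2  # is_cross_attention
        return dt_t4d_pairs * bool_axes * cross_options

    assert expected_rule_count(True) == 24
    assert expected_rule_count(False) == 48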
9 changes: 6 additions & 3 deletions onnxscript/rewriter/ort_fusions/mha_test.py
@@ -35,7 +35,8 @@ def test_smollm(self):
         # Fuse SDPA and MHA
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         if test_with_ort:
@@ -59,7 +60,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         onnxscript.optimizer.optimize(model)
 
@@ -84,7 +86,8 @@ def test_whisper_decoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         onnxscript.optimizer.optimize(model)
 