From 17f6fca42ffb231c557360d27afd91b10a69a083 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 22 May 2025 18:39:30 -0700
Subject: [PATCH 1/4] Ensure MHA rule ordering

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/_core.py |  7 +-
 onnxscript/rewriter/ort_fusions/mha.py   | 83 ++++++++++++------------
 2 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 5320cd5896..79de57f335 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -20,7 +20,7 @@
 from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
-from onnxscript.rewriter.ort_fusions.mha import fuse_mha
+from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
 from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
 from onnxscript.rewriter.ort_fusions.rotary_embedding import (
     fuse_partial_rotary_embedding,
@@ -87,8 +87,9 @@ def fuse(func, apply_shape_inference: bool = False):
     # in the rewrite rule for certain patterns of SDPA.
     fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
     # Optimize to avoid trying multiple attention-based fusions
-    fusion_count["mha"] = fuse(fuse_mha)
-    if fusion_count["mha"] == 0:
+    fusion_count["mha1"] = fuse(fuse_mha1)
+    fusion_count["mha2"] = fuse(fuse_mha2)
+    if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
         # If no MHA fusion was applied, we can try the GQA fusion.
         # and avoid trying the attention fusion.
         fusion_count["gqa"] = fuse(fuse_gqa)
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index 0985d5be23..1e30a2f7a7 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -375,46 +375,49 @@ def rewrite(
             _outputs=num_outputs,
         )
 
-
-parameter_combinations = [
-    {
-        "double_transpose": double_transpose,
-        "transpose_4d": transpose_4d,
-        "pre_scale_q": pre_scale_q,
-        "is_rotary": is_rotary,
-        "use_mask": use_mask,
-        "has_past_present": has_past_present,
-        "is_cross_attention": is_cross_attention,
-    }
-    for double_transpose in [False, True]
-    for transpose_4d in (
-        [False, True] if double_transpose else [False]
-    )  # Only generate patterns when double_transpose is True
-    for pre_scale_q in [True, False]
-    for is_rotary in [False, True]
-    for use_mask in [False, True]
-    for is_cross_attention in [False, True]
-    for has_past_present in ([False] if is_cross_attention else [True, False])
-    # Skip if both has_past_present and is_cross_attention are True
-    if not (has_past_present and is_cross_attention)
-]
-
-# Dynamically create the rules
-mha_rules = pattern.RewriteRuleSet(
-    [
-        MultiHeadAttention.rule(
-            f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
-            f"{'_Twice' if params['double_transpose'] else ''}"
-            f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
-            f"{'_Rotary' if params['is_rotary'] else ''}"
-            f"{'_Masked' if params['use_mask'] else ''}"
-            f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
-            **params,
-        )
-        for params in parameter_combinations
+def _make_rule_set(has_past_present:bool):
+    parameter_combinations = [
+        {
+            "double_transpose": double_transpose,
+            "transpose_4d": transpose_4d,
+            "pre_scale_q": pre_scale_q,
+            "is_rotary": is_rotary,
+            "use_mask": use_mask,
+            "has_past_present": has_past_present,
+            "is_cross_attention": is_cross_attention,
+        }
+        for double_transpose in [False, True]
+        for transpose_4d in (
+            [False, True] if double_transpose else [False]
+        )  # Only generate patterns when double_transpose is True
+        for pre_scale_q in [True, False]
+        for is_rotary in [False, True]
+        for use_mask in [False, True]
+        for is_cross_attention in [False] if has_past_present else [False, True]
     ]
-)
+
+    # Dynamically create the rules
+    mha_rules = pattern.RewriteRuleSet(
+        [
+            MultiHeadAttention.rule(
+                f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+                f"{'_Twice' if params['double_transpose'] else ''}"
+                f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
+                f"{'_Rotary' if params['is_rotary'] else ''}"
+                f"{'_Masked' if params['use_mask'] else ''}"
+                f"{'_Past' if params['has_past_present'] else ''}"
+                f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
+                **params,
+            )
+            for params in parameter_combinations
+        ]
+    )
+
+    return mha_rules
+
+mha_rules_no_past = _make_rule_set(has_past_present=False)
+mha_rules_with_past = _make_rule_set(has_past_present=True)
 
-fuse_mha = _fusion_utils.apply_fusion_rules(mha_rules)
+# Try rules with past first, and then rules without past.
+fuse_mha1 = _fusion_utils.apply_fusion_rules(mha_rules_with_past)
+fuse_mha2 = _fusion_utils.apply_fusion_rules(mha_rules_no_past)

From 230b4fd30ddbb8550ad0a8dfb8a81d1a58fb63b7 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 22 May 2025 18:49:34 -0700
Subject: [PATCH 2/4] Add parentheses

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/mha.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index 1e30a2f7a7..ea9ac6932f 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -375,7 +375,8 @@ def rewrite(
             _outputs=num_outputs,
         )
 
-def _make_rule_set(has_past_present:bool):
+
+def _make_rule_set(has_past_present: bool):
     parameter_combinations = [
         {
             "double_transpose": double_transpose,
@@ -393,7 +394,7 @@ def _make_rule_set(has_past_present:bool):
         for pre_scale_q in [True, False]
         for is_rotary in [False, True]
        for use_mask in [False, True]
-        for is_cross_attention in [False] if has_past_present else [False, True]
+        for is_cross_attention in ([False] if has_past_present else [False, True])
     ]
 
     # Dynamically create the rules
@@ -415,6 +416,7 @@ def _make_rule_set(has_past_present:bool):
 
     return mha_rules
 
+
 mha_rules_no_past = _make_rule_set(has_past_present=False)
 mha_rules_with_past = _make_rule_set(has_past_present=True)
 

From ed2131613e16e3760762b12a751d9950845ea9ad Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 22 May 2025 19:21:42 -0700
Subject: [PATCH 3/4] Fix references to fuse_mha

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/attention_test.py | 3 ++-
 onnxscript/rewriter/ort_fusions/mha_test.py       | 9 ++++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/attention_test.py b/onnxscript/rewriter/ort_fusions/attention_test.py
index aaedc3fc0a..fa62badf86 100644
--- a/onnxscript/rewriter/ort_fusions/attention_test.py
+++ b/onnxscript/rewriter/ort_fusions/attention_test.py
@@ -173,7 +173,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         fused_mha_bias_count = xformers.fuse_mha_bias(model)
         self.assertGreater(fused_mha_bias_count, 0)
diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py
index 8f4ed9715e..e7efb9c978 100644
--- a/onnxscript/rewriter/ort_fusions/mha_test.py
+++ b/onnxscript/rewriter/ort_fusions/mha_test.py
@@ -35,7 +35,8 @@ def test_smollm(self):
         # Fuse SDPA and MHA
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         if test_with_ort:
@@ -59,7 +60,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         onnxscript.optimizer.optimize(model)
@@ -84,7 +86,8 @@ def test_whisper_decoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         onnxscript.optimizer.optimize(model)

From 841dd99451a75cb4e56a79d117f48c1f0647a631 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 22 May 2025 19:41:08 -0700
Subject: [PATCH 4/4] Fix count of mha fusion

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/fuse_xformers_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py b/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
index 4c9c2ea416..d03093b346 100644
--- a/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
+++ b/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
@@ -27,7 +27,7 @@ def test_fuse_xformers(self):
         self.assertEqual(fusion_count["partial_rotary_embedding"], 0)
         self.assertEqual(fusion_count["cos_sin_cache"], 2)
         self.assertEqual(fusion_count["sdpa"], 1)
-        self.assertEqual(fusion_count["mha"], 1)
+        self.assertEqual(fusion_count["mha1"] + fusion_count["mha2"], 1)
         self.assertEqual(fusion_count["attention"], 0)
         self.assertEqual(fusion_count["gqa"], 0)
         self.assertEqual(fusion_count["gelu"], 0)
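
Usage sketch (not part of the patches above): after this series, the single fuse_mha entry point is replaced by two passes that callers apply in order, as the test changes do. A minimal sketch, assuming `model` is an onnxscript IR model that has already gone through SDPA fusion; the helper name fuse_all_mha is illustrative only:

    from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2

    def fuse_all_mha(model):
        # Rules for patterns with past key/value state are tried first,
        # then the rules for patterns without past state.
        count = fuse_mha1(model)
        count += fuse_mha2(model)
        return count  # total number of MHA fusions applied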