diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 5320cd5896..79de57f335 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -20,7 +20,7 @@
 from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
-from onnxscript.rewriter.ort_fusions.mha import fuse_mha
+from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
 from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
 from onnxscript.rewriter.ort_fusions.rotary_embedding import (
     fuse_partial_rotary_embedding,
@@ -87,8 +87,9 @@ def fuse(func, apply_shape_inference: bool = False):
     # in the rewrite rule for certain patterns of SDPA.
     fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
     # Optimize to avoid trying multiple attention-based fusions
-    fusion_count["mha"] = fuse(fuse_mha)
-    if fusion_count["mha"] == 0:
+    fusion_count["mha1"] = fuse(fuse_mha1)
+    fusion_count["mha2"] = fuse(fuse_mha2)
+    if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
         # If no MHA fusion was applied, we can try the GQA fusion.
         # and avoid trying the attention fusion.
         fusion_count["gqa"] = fuse(fuse_gqa)
diff --git a/onnxscript/rewriter/ort_fusions/attention_test.py b/onnxscript/rewriter/ort_fusions/attention_test.py
index aaedc3fc0a..fa62badf86 100644
--- a/onnxscript/rewriter/ort_fusions/attention_test.py
+++ b/onnxscript/rewriter/ort_fusions/attention_test.py
@@ -173,7 +173,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         fused_mha_bias_count = xformers.fuse_mha_bias(model)
         self.assertGreater(fused_mha_bias_count, 0)
diff --git a/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py b/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
index 4c9c2ea416..d03093b346 100644
--- a/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
+++ b/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
@@ -27,7 +27,7 @@ def test_fuse_xformers(self):
         self.assertEqual(fusion_count["partial_rotary_embedding"], 0)
         self.assertEqual(fusion_count["cos_sin_cache"], 2)
         self.assertEqual(fusion_count["sdpa"], 1)
-        self.assertEqual(fusion_count["mha"], 1)
+        self.assertEqual(fusion_count["mha1"] + fusion_count["mha2"], 1)
         self.assertEqual(fusion_count["attention"], 0)
         self.assertEqual(fusion_count["gqa"], 0)
         self.assertEqual(fusion_count["gelu"], 0)
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index 0985d5be23..ea9ac6932f 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -376,45 +376,50 @@ def rewrite(
         )
 
 
-parameter_combinations = [
-    {
-        "double_transpose": double_transpose,
-        "transpose_4d": transpose_4d,
-        "pre_scale_q": pre_scale_q,
-        "is_rotary": is_rotary,
-        "use_mask": use_mask,
-        "has_past_present": has_past_present,
-        "is_cross_attention": is_cross_attention,
-    }
-    for double_transpose in [False, True]
-    for transpose_4d in (
-        [False, True] if double_transpose else [False]
-    )  # Only generate patterns when double_transpose is True
-    for pre_scale_q in [True, False]
-    for is_rotary in [False, True]
-    for use_mask in [False, True]
-    for is_cross_attention in [False, True]
-    for has_past_present in ([False] if is_cross_attention else [True, False])
-    # Skip if both has_past_present and is_cross_attention are True
-    if not (has_past_present and is_cross_attention)
-]
-
-# Dynamically create the rules
-mha_rules = pattern.RewriteRuleSet(
-    [
-        MultiHeadAttention.rule(
-            f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
-            f"{'_Twice' if params['double_transpose'] else ''}"
-            f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
-            f"{'_Rotary' if params['is_rotary'] else ''}"
-            f"{'_Masked' if params['use_mask'] else ''}"
-            f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
-            **params,
-        )
-        for params in parameter_combinations
+def _make_rule_set(has_past_present: bool):
+    parameter_combinations = [
+        {
+            "double_transpose": double_transpose,
+            "transpose_4d": transpose_4d,
+            "pre_scale_q": pre_scale_q,
+            "is_rotary": is_rotary,
+            "use_mask": use_mask,
+            "has_past_present": has_past_present,
+            "is_cross_attention": is_cross_attention,
+        }
+        for double_transpose in [False, True]
+        for transpose_4d in (
+            [False, True] if double_transpose else [False]
+        )  # Only generate patterns when double_transpose is True
+        for pre_scale_q in [True, False]
+        for is_rotary in [False, True]
+        for use_mask in [False, True]
+        for is_cross_attention in ([False] if has_past_present else [False, True])
     ]
-)
 
+    # Dynamically create the rules
+    mha_rules = pattern.RewriteRuleSet(
+        [
+            MultiHeadAttention.rule(
+                f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+                f"{'_Twice' if params['double_transpose'] else ''}"
+                f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
+                f"{'_Rotary' if params['is_rotary'] else ''}"
+                f"{'_Masked' if params['use_mask'] else ''}"
+                f"{'_Past' if params['has_past_present'] else ''}"
+                f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
+                **params,
+            )
+            for params in parameter_combinations
+        ]
+    )
+
+    return mha_rules
+
+
+mha_rules_no_past = _make_rule_set(has_past_present=False)
+mha_rules_with_past = _make_rule_set(has_past_present=True)
 
-fuse_mha = _fusion_utils.apply_fusion_rules(mha_rules)
+# Try rules with past first, and then rules without past.
+fuse_mha1 = _fusion_utils.apply_fusion_rules(mha_rules_with_past)
+fuse_mha2 = _fusion_utils.apply_fusion_rules(mha_rules_no_past)
diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py
index 8f4ed9715e..e7efb9c978 100644
--- a/onnxscript/rewriter/ort_fusions/mha_test.py
+++ b/onnxscript/rewriter/ort_fusions/mha_test.py
@@ -35,7 +35,8 @@ def test_smollm(self):
         # Fuse SDPA and MHA
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         if test_with_ort:
@@ -59,7 +60,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         onnxscript.optimizer.optimize(model)
@@ -84,7 +86,8 @@ def test_whisper_decoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         onnxscript.optimizer.optimize(model)