From 038e8a840dc7511c199c4103bb756babcc231646 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 11 Aug 2025 10:55:15 -0700 Subject: [PATCH] Remove double transpose flag Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/mha.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index 433c10e504..e2987cfc5e 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -37,13 +37,11 @@ def __init__( self, name, *, - double_transpose: bool, is_rotary: bool, has_past_present: bool, is_cross_attention: bool, ): super().__init__(name) - self._double_transpose = double_transpose self._is_rotary = is_rotary self._has_past_present = has_past_present self._is_cross_attention = is_cross_attention @@ -345,12 +343,10 @@ def rewrite( def _make_rule_set(has_past_present: bool): parameter_combinations = [ { - "double_transpose": double_transpose, "is_rotary": is_rotary, "has_past_present": has_past_present, "is_cross_attention": is_cross_attention, } - for double_transpose in [False, True] for is_rotary in [False, True] for is_cross_attention in ([False] if has_past_present else [False, True]) ] @@ -360,7 +356,6 @@ def _make_rule_set(has_past_present: bool): [ MultiHeadAttention.rule( f"MHA" - f"{'_Twice' if params['double_transpose'] else ''}" f"{'_Rotary' if params['is_rotary'] else ''}" f"{'_Past' if params['has_past_present'] else ''}" f"{'_CrossAttention' if params['is_cross_attention'] else ''}",