4
4
5
5
from typing import ClassVar , Optional , Sequence
6
6
7
- from onnxscript .rewriter import _ir_utils
8
7
import onnxscript .rewriter .pattern as orp
9
8
from onnxscript import ir
9
+ from onnxscript .rewriter import _ir_utils
10
10
11
11
12
12
def _get_node (value : ir .Value , name : str ) -> ir .Node :
@@ -22,15 +22,6 @@ def _get_kwargs(node: ir.Node) -> dict[str, float | int]:
22
22
return kwargs
23
23
24
24
25
- def _get_int_or_default (node : ir .Node , name : str , default : int = 0 ) -> int :
26
- """Get the int value from the node attribute dictionary or return default."""
27
- if name in node .attributes :
28
- value = node .attributes [name ].as_int ()
29
- else :
30
- value = default
31
- return value
32
-
33
-
34
25
def _get_ints_or_default (
35
26
node : ir .Node , name : str , default : Optional [Sequence [int ]] = None
36
27
) -> Sequence [int ]:
@@ -103,14 +94,18 @@ def check(
103
94
expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
104
95
if perm != expected_perm :
105
96
return check_result .fail ("Permutation values for Transpose are not correct." )
106
- elif not (self ._pos == 1 and _ir_utils .has_rank (x , 2 )) and (self ._pos == 2 and _ir_utils .has_rank (y , 2 )):
97
+ elif (self ._pos == 1 and not _ir_utils .has_rank (x , 2 )) or (
98
+ self ._pos == 2 and not _ir_utils .has_rank (y , 2 )
99
+ ):
107
100
# If perm is not defined, the default transpose behavior is to swap
108
101
# all dimensions, which is correct for MatMul with rank = 2.
109
- return check_result .fail ("Permutation values for Transpose are not correct." )
102
+ return check_result .fail (
103
+ "If perm is not defined, rank must be 2 for TransposeMatMul rule."
104
+ )
110
105
if fused :
111
106
fused_node = _get_node (fused , "FusedMatMul" )
112
107
trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
113
- if _get_int_or_default ( fused_node , trans_batch_property ):
108
+ if fused_node . attributes . get_int ( trans_batch_property , 0 ):
114
109
return check_result .fail (
115
110
"FusedMatMul with transposed batch cannot be used with op.Transpose in this rule."
116
111
)
@@ -204,7 +199,7 @@ def check(
204
199
check_result = orp .MatchResult ()
205
200
fused_node = _get_node (fused , "FusedMatMul" )
206
201
trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
207
- trans_batch = _get_int_or_default ( fused_node , trans_batch_property )
202
+ trans_batch = fused_node . attributes . get_int ( trans_batch_property , 0 )
208
203
transposed_node = _get_node (transposed , "Transpose" )
209
204
perm = transposed_node .attributes ["perm" ].as_ints ()
210
205
if not perm :
@@ -312,16 +307,21 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
312
307
check_result = orp .MatchResult ()
313
308
transpose_node = _get_node (transposed , "Transpose" )
314
309
perm = _get_ints_or_default (transpose_node , "perm" )
315
- if perm :
316
- # Check that last two dimensions are swapped
317
- expected_perm = list (range (len (perm )))
318
- expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
319
- if perm != expected_perm :
320
- return check_result .fail ("Permutation values for Transpose are not correct." )
321
- elif not (self ._pos == 1 and _ir_utils .has_rank (x , 2 )) and (self ._pos == 2 and _ir_utils .has_rank (y , 2 )):
322
- # If perm is not defined, the default transpose behavior is to swap
323
- # all dimensions, which is correct for MatMul with rank = 2.
324
- return check_result .fail ("Permutation values for Transpose are not correct." )
310
+ # transA/transB only work on the last two dimensions of the input,
311
+ # so we can only apply this rule if the inputs are rank 2.
312
+ if _ir_utils .has_rank (x , 2 ) and _ir_utils .has_rank (y , 2 ):
313
+ if perm :
314
+ # Check that last two dimensions are swapped
315
+ expected_perm = list (range (len (perm )))
316
+ expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
317
+ if perm != expected_perm :
318
+ return check_result .fail (
319
+ "Permutation values for Transpose are not correct."
320
+ )
321
+ # If perm is not defined, the default transpose behavior is to swap
322
+ # all dimensions, which is correct for MatMul with rank = 2.
323
+ else :
324
+ return check_result .fail ("Rank must be 2 for MatMulTranspose rule." )
325
325
return check_result
326
326
327
327
def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
0 commit comments