Skip to content

Commit ccce52e

Browse files
committed
Rewrite tests and address comments
1 parent 73432e5 commit ccce52e

File tree

2 files changed

+363
-641
lines changed

2 files changed

+363
-641
lines changed

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from typing import ClassVar, Optional, Sequence
66

7-
from onnxscript.rewriter import _ir_utils
87
import onnxscript.rewriter.pattern as orp
98
from onnxscript import ir
9+
from onnxscript.rewriter import _ir_utils
1010

1111

1212
def _get_node(value: ir.Value, name: str) -> ir.Node:
@@ -22,15 +22,6 @@ def _get_kwargs(node: ir.Node) -> dict[str, float | int]:
2222
return kwargs
2323

2424

25-
def _get_int_or_default(node: ir.Node, name: str, default: int = 0) -> int:
26-
"""Get the int value from the node attribute dictionary or return default."""
27-
if name in node.attributes:
28-
value = node.attributes[name].as_int()
29-
else:
30-
value = default
31-
return value
32-
33-
3425
def _get_ints_or_default(
3526
node: ir.Node, name: str, default: Optional[Sequence[int]] = None
3627
) -> Sequence[int]:
@@ -103,14 +94,18 @@ def check(
10394
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
10495
if perm != expected_perm:
10596
return check_result.fail("Permutation values for Transpose are not correct.")
106-
elif not (self._pos == 1 and _ir_utils.has_rank(x, 2)) and (self._pos == 2 and _ir_utils.has_rank(y, 2)):
97+
elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or (
98+
self._pos == 2 and not _ir_utils.has_rank(y, 2)
99+
):
107100
# If perm is not defined, the default transpose behavior is to swap
108101
# all dimensions, which is correct for MatMul with rank = 2.
109-
return check_result.fail("Permutation values for Transpose are not correct.")
102+
return check_result.fail(
103+
"If perm is not defined, rank must be 2 for TransposeMatMul rule."
104+
)
110105
if fused:
111106
fused_node = _get_node(fused, "FusedMatMul")
112107
trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
113-
if _get_int_or_default(fused_node, trans_batch_property):
108+
if fused_node.attributes.get_int(trans_batch_property, 0):
114109
return check_result.fail(
115110
"FusedMatMul with transposed batch cannot be used with op.Transpose in this rule."
116111
)
@@ -204,7 +199,7 @@ def check(
204199
check_result = orp.MatchResult()
205200
fused_node = _get_node(fused, "FusedMatMul")
206201
trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
207-
trans_batch = _get_int_or_default(fused_node, trans_batch_property)
202+
trans_batch = fused_node.attributes.get_int(trans_batch_property, 0)
208203
transposed_node = _get_node(transposed, "Transpose")
209204
perm = transposed_node.attributes["perm"].as_ints()
210205
if not perm:
@@ -312,16 +307,21 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
312307
check_result = orp.MatchResult()
313308
transpose_node = _get_node(transposed, "Transpose")
314309
perm = _get_ints_or_default(transpose_node, "perm")
315-
if perm:
316-
# Check that last two dimensions are swapped
317-
expected_perm = list(range(len(perm)))
318-
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
319-
if perm != expected_perm:
320-
return check_result.fail("Permutation values for Transpose are not correct.")
321-
elif not (self._pos == 1 and _ir_utils.has_rank(x, 2)) and (self._pos == 2 and _ir_utils.has_rank(y, 2)):
322-
# If perm is not defined, the default transpose behavior is to swap
323-
# all dimensions, which is correct for MatMul with rank = 2.
324-
return check_result.fail("Permutation values for Transpose are not correct.")
310+
# transA/transB only work on the last two dimensions of the input,
311+
# so we can only apply this rule if the inputs are rank 2.
312+
if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2):
313+
if perm:
314+
# Check that last two dimensions are swapped
315+
expected_perm = list(range(len(perm)))
316+
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
317+
if perm != expected_perm:
318+
return check_result.fail(
319+
"Permutation values for Transpose are not correct."
320+
)
321+
# If perm is not defined, the default transpose behavior is to swap
322+
# all dimensions, which is correct for MatMul with rank = 2.
323+
else:
324+
return check_result.fail("Rank must be 2 for MatMulTranspose rule.")
325325
return check_result
326326

327327
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):

0 commit comments

Comments
 (0)