2 changes: 1 addition & 1 deletion onnxscript/rewriter/_fusion_utils.py
@@ -13,7 +13,7 @@
 Dim = Union[int, ir.SymbolicDim]
 
 
-def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
+def check_shape_bool(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
     if val.shape is None:
         return False
     if val.shape.rank() != len(shape):
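
Beyond the rename, it may help to see what the shape check does. Below is a minimal, self-contained sketch of the binding behavior implied by the rank check above and by the bindings dict that the fusion checks pass in; it is an illustration under those assumptions, not the library's actual implementation:

from typing import Union

Dim = Union[int, str]  # stand-in for int | ir.SymbolicDim

def check_shape_bool_sketch(bindings: dict[str, Dim], dims: list, expected: list) -> bool:
    # Sketch only: rank must match the pattern, as in the real check above.
    if len(dims) != len(expected):
        return False
    for dim, name in zip(dims, expected):
        if name not in bindings:
            bindings[name] = dim     # first occurrence of "B"/"S"/"D": record the binding
        elif bindings[name] != dim:  # later occurrence: must agree with the binding
            return False
    return True

bindings: dict[str, Dim] = {}
assert check_shape_bool_sketch(bindings, [2, 8, 64], ["B", "S", "D"])
assert not check_shape_bool_sketch(bindings, [4, 8, 64], ["B", "S", "D"])  # "B" already bound to 2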
9 changes: 8 additions & 1 deletion onnxscript/rewriter/_pattern_ir.py
@@ -126,7 +126,14 @@ def __init__(self, value: SupportedAttrTypes):
         self._value = value
 
     def matches(self, attr: ir.Attr) -> bool:
-        return isinstance(attr, ir.Attr) and attr.value == self._value
+        if attr.type in {
+            ir.AttributeType.INTS,
+            ir.AttributeType.FLOATS,
+            ir.AttributeType.STRINGS,
+        }:
+            # Since the type of attr.value is Sequence, we need to convert to the same type for comparison.
+            return tuple(attr.value) == tuple(self._value)
+        return attr.value == self._value
 
     def __str__(self) -> str:
         return str(self._value)

Review thread on the tuple conversion:
Member: should we cast every value in the tuple?
Collaborator (Author): Discussed offline. Will do when needed.
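
A short, plain-Python sketch of why the new branch normalizes both sides to tuple; the concrete sequence types here are an assumption for illustration, not a statement of the IR's contract:

stored = (1, 0)   # e.g. what attr.value might return for an INTS attribute
pattern = [1, 0]  # what a pattern author typically writes

assert stored != pattern                # list == tuple is always False in Python
assert tuple(stored) == tuple(pattern)  # normalized: the comparison now succeeds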
4 changes: 2 additions & 2 deletions onnxscript/rewriter/_rewrite_rule.py
@@ -392,7 +392,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
     if perm.is_ref():
         return False
     if perm.type == ir.AttributeType.INTS:
-        if perm.as_ints() == list(range(len(perm.as_ints()))):
+        if list(perm.as_ints()) == list(range(len(perm.as_ints()))):
             return True
     return False
     """
@@ -463,7 +463,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
     if perm.is_ref():
         return False
     if perm.type == ir.AttributeType.INTS:
-        if perm.as_ints() == list(range(len(perm.as_ints()))):
+        if list(perm.as_ints()) == list(range(len(perm.as_ints()))):
             return True
     return False
 
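
The same normalization fixes the identity-permutation test in this file's examples; illustrative values:

perm_values = (0, 1, 2)  # as_ints() may return an immutable sequence rather than a list
assert perm_values != list(range(len(perm_values)))        # tuple vs list: a false negative
assert list(perm_values) == list(range(len(perm_values)))  # identity permutation detected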
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/attention.py
@@ -160,7 +160,7 @@ def check(
         self.bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(self.bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
 
         if no_match(input, ["B", "S", "D"]):
             return check_result.fail(
6 changes: 3 additions & 3 deletions onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
@@ -79,7 +79,7 @@ def check(
             # Check that last two dimensions are swapped
             expected_perm = list(range(len(perm)))
             expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
-            if perm != expected_perm:
+            if list(perm) != expected_perm:
                 return check_result.fail("Permutation values for Transpose are not correct.")
         elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or (
             self._pos == 2 and not _ir_utils.has_rank(y, 2)
@@ -188,7 +188,7 @@ def check(
         trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
         trans_batch = fused_node.attributes.get_int(trans_batch_property, 0)
         transposed_node = _get_node(transposed, "Transpose")
-        perm = transposed_node.attributes["perm"].as_ints()
+        perm = list(transposed_node.attributes["perm"].as_ints())
         if not perm:
             return check_result.fail("Permutation values for Transpose are not correct.")
 
@@ -296,7 +296,7 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
         if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2):
             if perm:
                 # Check that the two dimensions are swapped
-                if perm != [1, 0]:
+                if tuple(perm) != (1, 0):
                     return check_result.fail(
                         "Permutation values for Transpose are not correct."
                     )
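
For reference, the last-two-dimensions check above as a runnable sketch; the tuple input stands in for what as_ints() may return:

perm = (0, 2, 1)  # e.g. a Transpose that swaps the last two axes of a rank-3 tensor
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
assert list(perm) == expected_perm  # normalize before comparing, as in the fix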
12 changes: 6 additions & 6 deletions onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py
@@ -284,7 +284,7 @@ def _check_model(
         opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul])
         expected = ref.run(None, feeds)
         got = opt.run(None, feeds)
-        self.assertEqual(len(expected), len(got))
+        self.assertEqual(len(got), len(expected))
         for a, b in zip(expected, got):
            np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
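
The assertion edits in this file only swap argument order, putting the computed value first so that unittest's "first != second" failure message reads as "got != expected". A minimal sketch of the convention (names are illustrative):

import unittest

class OrderDemo(unittest.TestCase):
    def test_order(self):
        got = ["Constant", "FusedMatMul"]       # computed: op types after rewriting
        expected = ["Constant", "FusedMatMul"]  # reference
        self.assertEqual(got, expected)         # actual first, expected second

if __name__ == "__main__":
    unittest.main()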

@@ -319,7 +319,7 @@ def test_fused_matmul_div_models(self, name, script_func, input_types, output_ty
         rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets()
         rule_set.apply_to_model(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["Constant", "FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["Constant", "FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
     @parameterized.parameterized.expand(
@@ -354,7 +354,7 @@ def test_fused_matmul_with_transpose(self, _, script_func):
         ir_model = ir.serde.deserialize_model(model_proto)
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
     @parameterized.parameterized.expand([("should_not_match", _should_not_match)])
@@ -366,8 +366,8 @@ def test_should_not_match(self, _, script_func):
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
         self.assertEqual(
-            ["Transpose", "MatMul", "Transpose"],
             [n.op_type for n in ir_model.graph],
+            ["Transpose", "MatMul", "Transpose"],
         )
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
@@ -391,7 +391,7 @@ def test_fused_matmul_with_other_node_in_middle(self, _, script_func):
         common_passes.ShapeInferencePass()(ir_model)
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["Identity", "FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["Identity", "FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
     @parameterized.parameterized.expand(
@@ -440,7 +440,7 @@ def test_transpose_fused_matmul_with_batch(self, _, script_func):
         ir_model = ir.serde.deserialize_model(model_proto)
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/gqa.py
@@ -247,7 +247,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(query_BSD, ["B", "S", "D"]):
             return False
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py
@@ -84,7 +84,7 @@ def check(
         self.bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(self.bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
 
         # Check that if x is being split into q, k, v correctly
         # based on hidden sizes
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/mha.py
@@ -157,7 +157,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(query_BSD, ["B", "S", "D"]):
             return check_result.fail(
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/mha_bias.py
@@ -78,7 +78,7 @@ def check(
         self.bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(self.bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
 
         if query_matmul.dtype not in valid_float_types:
             return check_result.fail("Query is not a float or float16 type.", query_matmul)
4 changes: 2 additions & 2 deletions onnxscript/rewriter/ort_fusions/skip_normalization.py
@@ -60,7 +60,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(input, ["B", "S", "D"]):
             return check_result.fail(
@@ -184,7 +184,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(input, ["B", "S", "D"]):
             return check_result.fail(
2 changes: 1 addition & 1 deletion onnxscript/rewriter/pattern_test.py
@@ -674,7 +674,7 @@ def test_model(x: FLOAT[1024, 512], y: FLOAT[1024, 512]) -> FLOAT[512, 1024]:
         function = model.functions[function_id]
         self.assertEqual([x.op_type for x in function], ["Add", "Transpose"])
         transpose_node = function[1]
-        self.assertEqual(transpose_node.attributes["perm"].value, [1, 0])
+        self.assertEqual(list(transpose_node.attributes["perm"].value), [1, 0])
         onnxscript.optimizer.inline(model)
         self.assertEqual([x.op_type for x in model.graph], ["Add", "Transpose"])