diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py
index dbf16ae3d3..f6a7204ac8 100644
--- a/onnxscript/rewriter/_fusion_utils.py
+++ b/onnxscript/rewriter/_fusion_utils.py
@@ -13,7 +13,7 @@
 Dim = Union[int, ir.SymbolicDim]
 
 
-def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
+def check_shape_bool(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
     if val.shape is None:
         return False
     if val.shape.rank() != len(shape):
diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py
index f64d3fca3c..9b81e33581 100644
--- a/onnxscript/rewriter/_pattern_ir.py
+++ b/onnxscript/rewriter/_pattern_ir.py
@@ -126,7 +126,14 @@ def __init__(self, value: SupportedAttrTypes):
         self._value = value
 
     def matches(self, attr: ir.Attr) -> bool:
-        return isinstance(attr, ir.Attr) and attr.value == self._value
+        if attr.type in {
+            ir.AttributeType.INTS,
+            ir.AttributeType.FLOATS,
+            ir.AttributeType.STRINGS,
+        }:
+            # Since the type of attr.value is Sequence, we need to convert to the same type for comparison.
+            return tuple(attr.value) == tuple(self._value)
+        return attr.value == self._value
 
     def __str__(self) -> str:
         return str(self._value)
diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py
index 9481ca5077..af0165dea0 100644
--- a/onnxscript/rewriter/_rewrite_rule.py
+++ b/onnxscript/rewriter/_rewrite_rule.py
@@ -392,7 +392,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
             if perm.is_ref():
                 return False
             if perm.type == ir.AttributeType.INTS:
-                if perm.as_ints() == list(range(len(perm.as_ints()))):
+                if list(perm.as_ints()) == list(range(len(perm.as_ints()))):
                     return True
             return False
     """
@@ -463,7 +463,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
             if perm.is_ref():
                 return False
             if perm.type == ir.AttributeType.INTS:
-                if perm.as_ints() == list(range(len(perm.as_ints()))):
+                if list(perm.as_ints()) == list(range(len(perm.as_ints()))):
                     return True
             return False
 
diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py
index 4a4cd0ad8e..ce234bbb63 100644
--- a/onnxscript/rewriter/ort_fusions/attention.py
+++ b/onnxscript/rewriter/ort_fusions/attention.py
@@ -160,7 +160,7 @@ def check(
         self.bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(self.bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
 
         if no_match(input, ["B", "S", "D"]):
             return check_result.fail(
diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
index 5082c20464..cdc50c99ae 100644
--- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
+++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
@@ -79,7 +79,7 @@ def check(
                 # Check that last two dimensions are swapped
                 expected_perm = list(range(len(perm)))
                 expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
-                if perm != expected_perm:
+                if list(perm) != expected_perm:
                     return check_result.fail("Permutation values for Transpose are not correct.")
         elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or (
             self._pos == 2 and not _ir_utils.has_rank(y, 2)
@@ -188,7 +188,7 @@ def check(
         trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
         trans_batch = fused_node.attributes.get_int(trans_batch_property, 0)
         transposed_node = _get_node(transposed, "Transpose")
-        perm = transposed_node.attributes["perm"].as_ints()
+        perm = list(transposed_node.attributes["perm"].as_ints())
 
         if not perm:
             return check_result.fail("Permutation values for Transpose are not correct.")
@@ -296,7 +296,7 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
         if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2):
             if perm:
                 # Check that the two dimensions are swapped
-                if perm != [1, 0]:
+                if tuple(perm) != (1, 0):
                     return check_result.fail(
                         "Permutation values for Transpose are not correct."
                     )
diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py
index 527d4826d5..f82702d557 100644
--- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py
+++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py
@@ -284,7 +284,7 @@ def _check_model(
         opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul])
         expected = ref.run(None, feeds)
         got = opt.run(None, feeds)
-        self.assertEqual(len(expected), len(got))
+        self.assertEqual(len(got), len(expected))
         for a, b in zip(expected, got):
             np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
 
@@ -319,7 +319,7 @@ def test_fused_matmul_div_models(self, name, script_func, input_types, output_ty
         rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets()
         rule_set.apply_to_model(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["Constant", "FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["Constant", "FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
     @parameterized.parameterized.expand(
@@ -354,7 +354,7 @@ def test_fused_matmul_with_transpose(self, _, script_func):
         ir_model = ir.serde.deserialize_model(model_proto)
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
     @parameterized.parameterized.expand([("should_not_match", _should_not_match)])
@@ -366,8 +366,8 @@ def test_should_not_match(self, _, script_func):
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
         self.assertEqual(
-            ["Transpose", "MatMul", "Transpose"],
             [n.op_type for n in ir_model.graph],
+            ["Transpose", "MatMul", "Transpose"],
         )
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
@@ -391,7 +391,7 @@ def test_fused_matmul_with_other_node_in_middle(self, _, script_func):
         common_passes.ShapeInferencePass()(ir_model)
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["Identity", "FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["Identity", "FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
     @parameterized.parameterized.expand(
@@ -440,7 +440,7 @@ def test_transpose_fused_matmul_with_batch(self, _, script_func):
         ir_model = ir.serde.deserialize_model(model_proto)
         self._apply_fusion_rules(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
-        self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph])
+        self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"])
         self._check_model(model_proto, rewritten_model, atol=1e-6)
 
diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 99852f712a..5fff910bcf 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -247,7 +247,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(query_BSD, ["B", "S", "D"]):
             return False
diff --git a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py
index 0d404b2754..51355fc8cf 100644
--- a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py
+++ b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py
@@ -84,7 +84,7 @@ def check(
         self.bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(self.bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
 
         # Check that if x is being split into q, k, v correctly
         # based on hidden sizes
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index e2987cfc5e..321e895f44 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -157,7 +157,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(query_BSD, ["B", "S", "D"]):
             return check_result.fail(
diff --git a/onnxscript/rewriter/ort_fusions/mha_bias.py b/onnxscript/rewriter/ort_fusions/mha_bias.py
index 28b9646ddc..9ecf2ce017 100644
--- a/onnxscript/rewriter/ort_fusions/mha_bias.py
+++ b/onnxscript/rewriter/ort_fusions/mha_bias.py
@@ -78,7 +78,7 @@ def check(
         self.bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(self.bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
 
         if query_matmul.dtype not in valid_float_types:
             return check_result.fail("Query is not a float or float16 type.", query_matmul)
diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py
index f7a376aef9..c76a7454cb 100644
--- a/onnxscript/rewriter/ort_fusions/skip_normalization.py
+++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py
@@ -60,7 +60,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(input, ["B", "S", "D"]):
             return check_result.fail(
@@ -184,7 +184,7 @@ def check(
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _fusion_utils._check_shape(bindings, val, dims)
+            return not _fusion_utils.check_shape_bool(bindings, val, dims)
 
         if no_match(input, ["B", "S", "D"]):
             return check_result.fail(
diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py
index 49ace2fb81..0a29080b4d 100644
--- a/onnxscript/rewriter/pattern_test.py
+++ b/onnxscript/rewriter/pattern_test.py
@@ -674,7 +674,7 @@ def test_model(x: FLOAT[1024, 512], y: FLOAT[1024, 512]) -> FLOAT[512, 1024]:
         function = model.functions[function_id]
         self.assertEqual([x.op_type for x in function], ["Add", "Transpose"])
         transpose_node = function[1]
-        self.assertEqual(transpose_node.attributes["perm"].value, [1, 0])
+        self.assertEqual(list(transpose_node.attributes["perm"].value), [1, 0])
         onnxscript.optimizer.inline(model)
         self.assertEqual([x.op_type for x in model.graph], ["Add", "Transpose"])
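Note (illustrative, not part of the patch): every `list(...)`/`tuple(...)` conversion above works around the same gotcha that the new comment in `_pattern_ir.py` describes — for INTS/FLOATS/STRINGS attributes, `attr.value` and `attr.as_ints()` yield a tuple-like immutable Sequence rather than a Python list, so a direct `==` against a list literal is always false. A minimal standalone sketch of that behaviour, using a plain tuple as a stand-in for the attribute value:

```python
# Illustrative sketch only: `perm` stands in for what attr.as_ints() / attr.value
# returns for an INTS attribute, assumed to be a tuple-like Sequence, not a list.
perm = (1, 0)

print(perm == [1, 0])         # False: a tuple never compares equal to a list
print(list(perm) == [1, 0])   # True: normalize to list, as in fused_matmul_rule_sets.py
print(tuple(perm) == (1, 0))  # True: normalize both sides to tuple, as in _pattern_ir.py
```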