diff --git a/tests/filecheck/transforms/convert_onnx_to_linalg.mlir b/tests/filecheck/transforms/convert_onnx_to_linalg.mlir
index fc38b13f04..dc45785d8b 100644
--- a/tests/filecheck/transforms/convert_onnx_to_linalg.mlir
+++ b/tests/filecheck/transforms/convert_onnx_to_linalg.mlir
@@ -1,12 +1,17 @@
 // RUN: xdsl-opt -p convert-onnx-to-linalg %s | filecheck %s
 
+// CHECK: builtin.module {
+
 %t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>)
 %res_add = onnx.Add(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+%res_sub = onnx.Sub(%t0, %t1) {onnx_node_name = "/Sub"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
 
-// CHECK: builtin.module {
 // CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>)
 // CHECK-NEXT: %res_add = tensor.empty() : tensor<3x2xf32>
 // CHECK-NEXT: %res_add_1 = linalg.add ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_add : tensor<3x2xf32>) -> tensor<3x2xf32>
+// CHECK-NEXT: %res_sub = tensor.empty() : tensor<3x2xf32>
+// CHECK-NEXT: %res_sub_1 = linalg.sub ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_sub : tensor<3x2xf32>) -> tensor<3x2xf32>
+
 
 %t2 = "test.op"() : () -> (tensor<3x4xf32>)
 %res_relu = "onnx.Relu"(%t2) {onnx_node_name = "/Relu"}: (tensor<3x4xf32>) -> tensor<3x4xf32>
diff --git a/xdsl/transforms/convert_onnx_to_linalg.py b/xdsl/transforms/convert_onnx_to_linalg.py
index a796efeda3..94626a57cc 100644
--- a/xdsl/transforms/convert_onnx_to_linalg.py
+++ b/xdsl/transforms/convert_onnx_to_linalg.py
@@ -59,6 +59,27 @@ def match_and_rewrite(self, add: onnx.Add, rewriter: PatternRewriter, /):
         )
 
 
+@dataclass
+class SubOpLowering(RewritePattern):
+    @op_type_rewrite_pattern
+    def match_and_rewrite(self, sub: onnx.Sub, rewriter: PatternRewriter, /):
+        lhs_type = sub.lhs.type
+        rhs_type = sub.rhs.type
+        if isinstance(lhs_type, TensorType) and isinstance(rhs_type, TensorType):
+            lhs_shape = lhs_type.get_shape()
+            rhs_shape = rhs_type.get_shape()
+
+            if -1 in lhs_shape or -1 in rhs_shape:
+                raise NotImplementedError()
+
+        rewriter.replace_matched_op(
+            (
+                empty := tensor.EmptyOp((), sub.res.type),
+                linalg.SubOp((sub.lhs, sub.rhs), (empty.tensor,), res=(sub.res.type,)),
+            )
+        )
+
+
 @dataclass
 class ReluOpLowering(RewritePattern):
     @op_type_rewrite_pattern
@@ -378,6 +399,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
             GreedyRewritePatternApplier(
                 [
                     AddOpLowering(),
+                    SubOpLowering(),
                     ReluOpLowering(),
                     ConstantOpLowering(),
                     ReshapeOpLowering(),