
transformations: (onnx) fix lower onnx.Relu lowering #2435

Merged · 9 commits · May 1, 2024 · Changes from all commits
57 changes: 34 additions & 23 deletions tests/filecheck/transforms/convert_onnx_to_linalg.mlir
@@ -3,23 +3,35 @@
%t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>)
%res_add = onnx.Add(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>


// CHECK: builtin.module {
// CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>)
// CHECK-NEXT: %res_add = tensor.empty() : tensor<3x2xf32>
// CHECK-NEXT: %res_add_1 = linalg.add ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_add : tensor<3x2xf32>) -> tensor<3x2xf32>

%t2 = "test.op"() : () -> (tensor<3x4xf64>)
%res_relu = "onnx.Relu"(%t2) {onnx_node_name = "/Relu"}: (tensor<3x4xf64>) -> tensor<3x4xf64>
%t2 = "test.op"() : () -> (tensor<3x4xf32>)
Member: Surely we should support all float types? Can you get the arg type from the parameter type in the rewrite pattern?
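
A minimal sketch of what that could look like inside the pattern (a hypothetical fragment, mirroring the + lines in the Python diff further down; it assumes the TensorType/FloatAttr/AnyFloat imports used in that file):

    # Sketch: derive the zero constant's type from the operand instead of hard-coding f64.
    operand = relu.operand.type
    assert isinstance(operand, TensorType)  # the Relu input is a ranked tensor
    elem = operand.element_type             # f32, f64, ... whatever the input carries
    zero = arith.Constant(FloatAttr(0.0, cast(AnyFloat, elem)))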

Member: I would say let's duplicate this test, to make sure that we handle both f32 and f64?

+%res_relu = "onnx.Relu"(%t2) {onnx_node_name = "/Relu"}: (tensor<3x4xf32>) -> tensor<3x4xf32>

+// CHECK-NEXT: %t2 = "test.op"() : () -> tensor<3x4xf32>
+// CHECK-NEXT: %res_relu = tensor.empty() : tensor<3x4xf32>
+// CHECK-NEXT: %res_relu_1 = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %res_relu_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t2 : tensor<3x4xf32>) outs(%res_relu : tensor<3x4xf32>) {
+// CHECK-NEXT: ^0(%0 : f32, %1 : f32):
+// CHECK-NEXT: %2 = arith.maximumf %0, %res_relu_1 : f32
+// CHECK-NEXT: linalg.yield %2 : f32
+// CHECK-NEXT: } -> tensor<3x4xf32>
+
+%t27 = "test.op"() : () -> (tensor<3x4xf64>)
+%res_relu_3 = "onnx.Relu"(%t27) {onnx_node_name = "/Relu"}: (tensor<3x4xf64>) -> tensor<3x4xf64>
+
+// CHECK-NEXT: %t27 = "test.op"() : () -> tensor<3x4xf64>
+// CHECK-NEXT: %res_relu_3 = tensor.empty() : tensor<3x4xf64>
+// CHECK-NEXT: %res_relu_3_1 = arith.constant 0.000000e+00 : f64
+// CHECK-NEXT: %res_relu_3_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t27 : tensor<3x4xf64>) outs(%res_relu_3 : tensor<3x4xf64>) {
+// CHECK-NEXT: ^1(%3 : f64, %4 : f64):
+// CHECK-NEXT: %5 = arith.maximumf %3, %res_relu_3_1 : f64
+// CHECK-NEXT: linalg.yield %5 : f64
+// CHECK-NEXT: } -> tensor<3x4xf64>

-// CHECK-NEXT: %t2 = "test.op"() : () -> tensor<3x4xf64>
-// CHECK-NEXT: %res_relu = tensor.empty() : tensor<3x4xf64>
-// CHECK-NEXT: %res_relu_1 = arith.constant 0.000000e+00 : f64
-// CHECK-NEXT: %res_relu_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t2 : tensor<3x4xf64>) outs(%res_relu : tensor<3x4xf64>) {
-// CHECK-NEXT: ^0(%0 : f64, %1 : f64):
-// CHECK-NEXT: %2 = arith.maximumf %0, %res_relu_1 : f64
-// CHECK-NEXT: linalg.yield %2 : f64
-// CHECK-NEXT: } -> tensor<3x4xf64>

%t3,%t4 = "test.op"(): () -> (tensor<20x2xf32>, tensor<2xi64>)
%res_reshape = "onnx.Reshape"(%t3, %t4) {onnx_node_name = "/Reshape"}: (tensor<20x2xf32>, tensor<2xi64>) -> tensor<1x40xf32>
@@ -31,10 +31,10 @@
%res_gemm= "onnx.Gemm"(%t5, %t6, %t7) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32>

// CHECK-NEXT: %t5, %t6, %t7 = "test.op"() : () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>)
-// CHECK-NEXT: %3 = tensor.empty() : tensor<320x50xf32>
-// CHECK-NEXT: %4 = linalg.transpose ins(%t6:tensor<50x320xf32>) outs(%3:tensor<320x50xf32>) permutation = [1, 0]
+// CHECK-NEXT: %6 = tensor.empty() : tensor<320x50xf32>
Member: In order to avoid these spurious changes, I'd recommend using regex patterns:

Suggested change:
-// CHECK-NEXT: %6 = tensor.empty() : tensor<320x50xf32>
+// CHECK-NEXT: %{{.*}} = tensor.empty() : tensor<320x50xf32>

You can also pip install filecheckize and use it with --mlir-anonymize to convert MLIR IR to FileCheck tests, something like this:

xdsl-opt file.mlir -p my-awesome-pass | filecheckize --mlir-anonymize

Collaborator (author): Oops, I merged it already; I'll keep these in mind for the future, thank you.

+// CHECK-NEXT: %7 = linalg.transpose ins(%t6:tensor<50x320xf32>) outs(%6:tensor<320x50xf32>) permutation = [1, 0]
// CHECK-NEXT: %res_gemm = tensor.empty() : tensor<1x50xf32>
-// CHECK-NEXT: %res_gemm_1 = linalg.matmul ins(%t5, %4 : tensor<1x320xf32>, tensor<320x50xf32>) outs(%res_gemm : tensor<1x50xf32>) -> tensor<1x50xf32>
+// CHECK-NEXT: %res_gemm_1 = linalg.matmul ins(%t5, %7 : tensor<1x320xf32>, tensor<320x50xf32>) outs(%res_gemm : tensor<1x50xf32>) -> tensor<1x50xf32>
// CHECK-NEXT: %res_gemm_2 = linalg.add ins(%res_gemm_1, %t7 : tensor<1x50xf32>, tensor<50xf32>) outs(%res_gemm_1 : tensor<1x50xf32>) -> tensor<1x50xf32>


@@ -54,15 +54,15 @@


// CHECK-NEXT: %t11, %t12, %t13 = "test.op"() : () -> (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>)
-// CHECK-NEXT: %5 = tensor.empty() : tensor<5x10xf32>
-// CHECK-NEXT: %6 = linalg.transpose ins(%t11:tensor<10x5xf32>) outs(%5:tensor<5x10xf32>) permutation = [1, 0]
-// CHECK-NEXT: %7 = arith.constant 5.000000e-01 : f32
-// CHECK-NEXT: %8 = linalg.mul ins(%7, %6 : f32, tensor<5x10xf32>) outs(%6 : tensor<5x10xf32>) -> tensor<5x10xf32>
-// CHECK-NEXT: %9 = arith.constant 5.000000e-01 : f32
-// CHECK-NEXT: %10 = linalg.mul ins(%9, %t13 : f32, tensor<5x3xf32>) outs(%t13 : tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK-NEXT: %8 = tensor.empty() : tensor<5x10xf32>
+// CHECK-NEXT: %9 = linalg.transpose ins(%t11:tensor<10x5xf32>) outs(%8:tensor<5x10xf32>) permutation = [1, 0]
+// CHECK-NEXT: %10 = arith.constant 5.000000e-01 : f32
+// CHECK-NEXT: %11 = linalg.mul ins(%10, %9 : f32, tensor<5x10xf32>) outs(%9 : tensor<5x10xf32>) -> tensor<5x10xf32>
+// CHECK-NEXT: %12 = arith.constant 5.000000e-01 : f32
+// CHECK-NEXT: %13 = linalg.mul ins(%12, %t13 : f32, tensor<5x3xf32>) outs(%t13 : tensor<5x3xf32>) -> tensor<5x3xf32>
// CHECK-NEXT: %res_gemm_2 = tensor.empty() : tensor<5x3xf32>
-// CHECK-NEXT: %res_gemm_2_1 = linalg.matmul ins(%8, %t12 : tensor<5x10xf32>, tensor<10x3xf32>) outs(%res_gemm_2 : tensor<5x3xf32>) -> tensor<5x3xf32>
-// CHECK-NEXT: %res_gemm_2_2 = linalg.add ins(%res_gemm_2_1, %10 : tensor<5x3xf32>, tensor<5x3xf32>) outs(%res_gemm_2_1 : tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK-NEXT: %res_gemm_2_1 = linalg.matmul ins(%11, %t12 : tensor<5x10xf32>, tensor<10x3xf32>) outs(%res_gemm_2 : tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK-NEXT: %res_gemm_2_2 = linalg.add ins(%res_gemm_2_1, %13 : tensor<5x3xf32>, tensor<5x3xf32>) outs(%res_gemm_2_1 : tensor<5x3xf32>) -> tensor<5x3xf32>

%t26 = "test.op"(): () -> (tensor<1x16x14x14xf32>)
%res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t26) {onnx_node_name = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : si64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : si64, strides = [3 : i64, 3 : i64]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32>
@@ -89,7 +101,6 @@
// CHECK-NEXT: %res_conv_3_1 = linalg.conv_2d_nchw_fchw {"dilations" = dense<1> : tensor<2xi64>, "strides" = dense<1> : tensor<2xi64>} ins(%t23, %t24 : tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>) outs(%res_conv_3 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
// CHECK-NEXT: %res_conv_3_2 = linalg.add ins(%t25 : tensor<16xf32>) outs(%res_conv_3_1 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>

-
%res_constant = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<1> : tensor<1xi64>}: () -> tensor<1xi64>
%res_constant_2 = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<2.0> : tensor<1x5xf32>} : () -> tensor<1x5xf32>

23 changes: 12 additions & 11 deletions xdsl/transforms/convert_onnx_to_linalg.py
@@ -5,6 +5,7 @@
from xdsl.dialects import arith, linalg, ml_program, onnx, tensor
from xdsl.dialects.builtin import (
    AffineMapAttr,
+    AnyFloat,
    DenseArrayBase,
    DenseIntOrFPElementsAttr,
    FloatAttr,
@@ -62,24 +62,24 @@ def match_and_rewrite(self, add: onnx.Add, rewriter: PatternRewriter, /):
class ReluOpLowering(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, relu: onnx.Relu, rewriter: PatternRewriter, /):
-        body = Region(Block(arg_types=(f64, f64)))
-        affine_map = AffineMapAttr(AffineMap.from_callable(lambda d0, d1: (d0, d1)))
+        operand = relu.operand.type
+        assert isinstance(operand, TensorType)
+        operand = cast(TensorType[Attribute], operand)
+        operand_rank = len(operand.get_shape())
+        body = Region(Block(arg_types=(operand.element_type, operand.element_type)))
+        affine_map = AffineMapAttr(AffineMap.identity(operand_rank))
        rewriter.replace_matched_op(
            (
                empty := tensor.EmptyOp((), relu.res.type),
-                zero := arith.Constant(FloatAttr(0, f64)),
+                zero := arith.Constant(
+                    FloatAttr(0.0, cast(AnyFloat, operand.element_type))
+                ),
                linalg.Generic(
                    (relu.operand,),
                    (empty.tensor,),
                    body,
-                    (
-                        affine_map,
-                        affine_map,
-                    ),
-                    (
-                        linalg.IteratorTypeAttr.parallel(),
-                        linalg.IteratorTypeAttr.parallel(),
-                    ),
+                    (affine_map, affine_map),
+                    (linalg.IteratorTypeAttr.parallel(),) * operand_rank,
                    (relu.res.type,),
                ),
            )
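
As a sanity check on the map change: AffineMap.identity(2) reproduces the lambda-based map that was removed, while higher ranks now come for free. A small sketch (import path assumed from xDSL's public API):

    from xdsl.ir.affine import AffineMap

    # Rank 2: the identity map equals the old hand-written map (d0, d1) -> (d0, d1).
    assert AffineMap.identity(2) == AffineMap.from_callable(lambda d0, d1: (d0, d1))

    # Rank 3 works with no further changes: (d0, d1, d2) -> (d0, d1, d2),
    # paired with (linalg.IteratorTypeAttr.parallel(),) * 3 in the pattern.
    print(AffineMap.identity(3))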