transformations: (onnx) fix onnx.Relu lowering #2435
Changes from all commits: d3c377d, 2ccbb33, c29c6b7, 7d5e3d2, 465ff6d, b09def4, 1adcb81, eaa8c32, 7216b53
@@ -3,23 +3,35 @@
 %t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>)
 %res_add = onnx.Add(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>

 // CHECK: builtin.module {
 // CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>)
 // CHECK-NEXT: %res_add = tensor.empty() : tensor<3x2xf32>
 // CHECK-NEXT: %res_add_1 = linalg.add ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_add : tensor<3x2xf32>) -> tensor<3x2xf32>

-%t2 = "test.op"() : () -> (tensor<3x4xf64>)
-%res_relu = "onnx.Relu"(%t2) {onnx_node_name = "/Relu"}: (tensor<3x4xf64>) -> tensor<3x4xf64>
+%t2 = "test.op"() : () -> (tensor<3x4xf32>)
+%res_relu = "onnx.Relu"(%t2) {onnx_node_name = "/Relu"}: (tensor<3x4xf32>) -> tensor<3x4xf32>

+// CHECK-NEXT: %t2 = "test.op"() : () -> tensor<3x4xf32>
+// CHECK-NEXT: %res_relu = tensor.empty() : tensor<3x4xf32>
+// CHECK-NEXT: %res_relu_1 = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %res_relu_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t2 : tensor<3x4xf32>) outs(%res_relu : tensor<3x4xf32>) {
+// CHECK-NEXT: ^0(%0 : f32, %1 : f32):
+// CHECK-NEXT: %2 = arith.maximumf %0, %res_relu_1 : f32
+// CHECK-NEXT: linalg.yield %2 : f32
+// CHECK-NEXT: } -> tensor<3x4xf32>
+
+%t27 = "test.op"() : () -> (tensor<3x4xf64>)
+%res_relu_3 = "onnx.Relu"(%t27) {onnx_node_name = "/Relu"}: (tensor<3x4xf64>) -> tensor<3x4xf64>
+
+// CHECK-NEXT: %t27 = "test.op"() : () -> tensor<3x4xf64>
+// CHECK-NEXT: %res_relu_3 = tensor.empty() : tensor<3x4xf64>
+// CHECK-NEXT: %res_relu_3_1 = arith.constant 0.000000e+00 : f64
+// CHECK-NEXT: %res_relu_3_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t27 : tensor<3x4xf64>) outs(%res_relu_3 : tensor<3x4xf64>) {
+// CHECK-NEXT: ^1(%3 : f64, %4 : f64):
+// CHECK-NEXT: %5 = arith.maximumf %3, %res_relu_3_1 : f64
+// CHECK-NEXT: linalg.yield %5 : f64
+// CHECK-NEXT: } -> tensor<3x4xf64>

-// CHECK-NEXT: %t2 = "test.op"() : () -> tensor<3x4xf64>
-// CHECK-NEXT: %res_relu = tensor.empty() : tensor<3x4xf64>
-// CHECK-NEXT: %res_relu_1 = arith.constant 0.000000e+00 : f64
-// CHECK-NEXT: %res_relu_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t2 : tensor<3x4xf64>) outs(%res_relu : tensor<3x4xf64>) {
-// CHECK-NEXT: ^0(%0 : f64, %1 : f64):
-// CHECK-NEXT: %2 = arith.maximumf %0, %res_relu_1 : f64
-// CHECK-NEXT: linalg.yield %2 : f64
-// CHECK-NEXT: } -> tensor<3x4xf64>

 %t3,%t4 = "test.op"(): () -> (tensor<20x2xf32>, tensor<2xi64>)
 %res_reshape = "onnx.Reshape"(%t3, %t4) {onnx_node_name = "/Reshape"}: (tensor<20x2xf32>, tensor<2xi64>) -> tensor<1x40xf32>
@@ -31,10 +43,10 @@
 %res_gemm= "onnx.Gemm"(%t5, %t6, %t7) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32>

 // CHECK-NEXT: %t5, %t6, %t7 = "test.op"() : () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>)
-// CHECK-NEXT: %3 = tensor.empty() : tensor<320x50xf32>
-// CHECK-NEXT: %4 = linalg.transpose ins(%t6:tensor<50x320xf32>) outs(%3:tensor<320x50xf32>) permutation = [1, 0]
+// CHECK-NEXT: %6 = tensor.empty() : tensor<320x50xf32>
Review comment: In order to avoid these spurious changes, I'd recommend using regex patterns:
Suggested change
You can also …

Author reply: oops, I merged it already; I'll keep these in mind for the future, thank you!
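The reviewer's concrete suggestion is not preserved above; purely to illustrate the regex-pattern idea, FileCheck value captures along these lines would keep the test stable under SSA renumbering (the capture names INIT and TRANS are invented for this sketch):

```mlir
// Capture the unnamed SSA results instead of hard-coding %3 / %4:
// CHECK-NEXT: %[[INIT:.*]] = tensor.empty() : tensor<320x50xf32>
// CHECK-NEXT: %[[TRANS:.*]] = linalg.transpose ins(%t6:tensor<50x320xf32>) outs(%[[INIT]]:tensor<320x50xf32>) permutation = [1, 0]
// CHECK-NEXT: %res_gemm = tensor.empty() : tensor<1x50xf32>
// CHECK-NEXT: %res_gemm_1 = linalg.matmul ins(%t5, %[[TRANS]] : tensor<1x320xf32>, tensor<320x50xf32>) outs(%res_gemm : tensor<1x50xf32>) -> tensor<1x50xf32>
```

With captures like these, a renumbering elsewhere in the file (here, %3/%4 becoming %6/%7) would no longer touch these lines.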
+// CHECK-NEXT: %7 = linalg.transpose ins(%t6:tensor<50x320xf32>) outs(%6:tensor<320x50xf32>) permutation = [1, 0]
 // CHECK-NEXT: %res_gemm = tensor.empty() : tensor<1x50xf32>
-// CHECK-NEXT: %res_gemm_1 = linalg.matmul ins(%t5, %4 : tensor<1x320xf32>, tensor<320x50xf32>) outs(%res_gemm : tensor<1x50xf32>) -> tensor<1x50xf32>
+// CHECK-NEXT: %res_gemm_1 = linalg.matmul ins(%t5, %7 : tensor<1x320xf32>, tensor<320x50xf32>) outs(%res_gemm : tensor<1x50xf32>) -> tensor<1x50xf32>
 // CHECK-NEXT: %res_gemm_2 = linalg.add ins(%res_gemm_1, %t7 : tensor<1x50xf32>, tensor<50xf32>) outs(%res_gemm_1 : tensor<1x50xf32>) -> tensor<1x50xf32>
@@ -54,15 +66,15 @@

 // CHECK-NEXT: %t11, %t12, %t13 = "test.op"() : () -> (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>)
-// CHECK-NEXT: %5 = tensor.empty() : tensor<5x10xf32>
-// CHECK-NEXT: %6 = linalg.transpose ins(%t11:tensor<10x5xf32>) outs(%5:tensor<5x10xf32>) permutation = [1, 0]
-// CHECK-NEXT: %7 = arith.constant 5.000000e-01 : f32
-// CHECK-NEXT: %8 = linalg.mul ins(%7, %6 : f32, tensor<5x10xf32>) outs(%6 : tensor<5x10xf32>) -> tensor<5x10xf32>
-// CHECK-NEXT: %9 = arith.constant 5.000000e-01 : f32
-// CHECK-NEXT: %10 = linalg.mul ins(%9, %t13 : f32, tensor<5x3xf32>) outs(%t13 : tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK-NEXT: %8 = tensor.empty() : tensor<5x10xf32>
+// CHECK-NEXT: %9 = linalg.transpose ins(%t11:tensor<10x5xf32>) outs(%8:tensor<5x10xf32>) permutation = [1, 0]
+// CHECK-NEXT: %10 = arith.constant 5.000000e-01 : f32
+// CHECK-NEXT: %11 = linalg.mul ins(%10, %9 : f32, tensor<5x10xf32>) outs(%9 : tensor<5x10xf32>) -> tensor<5x10xf32>
+// CHECK-NEXT: %12 = arith.constant 5.000000e-01 : f32
+// CHECK-NEXT: %13 = linalg.mul ins(%12, %t13 : f32, tensor<5x3xf32>) outs(%t13 : tensor<5x3xf32>) -> tensor<5x3xf32>
 // CHECK-NEXT: %res_gemm_2 = tensor.empty() : tensor<5x3xf32>
-// CHECK-NEXT: %res_gemm_2_1 = linalg.matmul ins(%8, %t12 : tensor<5x10xf32>, tensor<10x3xf32>) outs(%res_gemm_2 : tensor<5x3xf32>) -> tensor<5x3xf32>
-// CHECK-NEXT: %res_gemm_2_2 = linalg.add ins(%res_gemm_2_1, %10 : tensor<5x3xf32>, tensor<5x3xf32>) outs(%res_gemm_2_1 : tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK-NEXT: %res_gemm_2_1 = linalg.matmul ins(%11, %t12 : tensor<5x10xf32>, tensor<10x3xf32>) outs(%res_gemm_2 : tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK-NEXT: %res_gemm_2_2 = linalg.add ins(%res_gemm_2_1, %13 : tensor<5x3xf32>, tensor<5x3xf32>) outs(%res_gemm_2_1 : tensor<5x3xf32>) -> tensor<5x3xf32>

 %t26 = "test.op"(): () -> (tensor<1x16x14x14xf32>)
 %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t26) {onnx_node_name = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : si64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : si64, strides = [3 : i64, 3 : i64]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32>
@@ -89,7 +101,6 @@
 // CHECK-NEXT: %res_conv_3_1 = linalg.conv_2d_nchw_fchw {"dilations" = dense<1> : tensor<2xi64>, "strides" = dense<1> : tensor<2xi64>} ins(%t23, %t24 : tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>) outs(%res_conv_3 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
 // CHECK-NEXT: %res_conv_3_2 = linalg.add ins(%t25 : tensor<16xf32>) outs(%res_conv_3_1 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>

-
 %res_constant = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<1> : tensor<1xi64>}: () -> tensor<1xi64>
 %res_constant_2 = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<2.0> : tensor<1x5xf32>} : () -> tensor<1x5xf32>
Review comment: Surely we should support all float types? Can you get the arg type from the parameter type in the rewrite pattern?

Review comment: I would say let's duplicate this test, to make sure that we handle both f32 and f64.
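As a complement to duplicating the test, a single check pattern could cover both float widths once the lowering derives the zero constant's type from the operand's element type. A minimal sketch of the FileCheck idiom, with invented capture names C and T:

```mlir
// One check block matching either f32 or f64:
// CHECK: [[C:%.*]] = arith.constant 0.000000e+00 : [[T:f(32|64)]]
// CHECK: %{{.*}} = arith.maximumf %{{.*}}, [[C]] : [[T]]
```

Duplicating the test, as done in this PR, keeps each expected output exact, which is arguably easier to debug when a lowering regresses.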