diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index b57e8b2fb8cb..4dc678e31ddb 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -442,6 +442,8 @@ def Tosa_AddOp : Tosa_Op<"add", [ let results = (outs Tosa_Tensor:$output ); + + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 90146f5bc29b..07d3e6e67c21 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -289,6 +289,55 @@ void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +struct AddZeroOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::AddOp op, + PatternRewriter &rewriter) const override { + auto input1 = op.input1(); + auto input2 = op.input2(); + + DenseElementsAttr input1Attr; + if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && + input2.getType() == op.getType()) { + if (input1Attr.getType().getElementType().isa() && + input1Attr.getSplatValue().isZero()) { + rewriter.replaceOp(op, op.input2()); + return success(); + } + + if (input1Attr.getType().getElementType().isa() && + input1Attr.getSplatValue().isZero()) { + rewriter.replaceOp(op, op.input2()); + return success(); + } + } + + DenseElementsAttr input2Attr; + if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && + input1.getType() == op.getType()) { + if (input2Attr.getType().getElementType().isa() && + input2Attr.getSplatValue().isZero()) { + rewriter.replaceOp(op, op.input1()); + return success(); + } + + if (input2Attr.getType().getElementType().isa() && + input2Attr.getSplatValue().isZero()) { + rewriter.replaceOp(op, op.input1()); + return success(); + } + } + + return failure(); + } +}; + +void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index e6cf1a15ac67..e4614853c71e 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -9,6 +9,38 @@ func @argmax_nofold(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @add_zero_different_shape +func @add_zero_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> { + // CHECK: tosa.add + %zeros = "tosa.const"() {value = dense<0.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32> + %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32> + return %1 : tensor<4x2x3xf32> +} + +// ----- + +// CHECK-LABEL: @add_zero_float +func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.add + %zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: @add_zero_int +func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.add + %zeros = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> + %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %1 : tensor<2x3xi32> +} + +// ----- + // CHECK-LABEL: @cast_fold func @cast_fold(%arg0: tensor) -> tensor { // CHECK: return %arg0