diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index e69c40f2b0523..7444f70a46e93 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -771,7 +771,7 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } ShapedType inputTy = llvm::cast(getInput().getType()); \ if (!inputTy.hasRank()) \ return {}; \ - if (inputTy.getDimSize(getAxis()) == 1) \ + if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \ return getInput(); \ return {}; \ } @@ -874,7 +874,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { return operandAttr; // If the dim-length is 1, tosa.reverse is a no-op. - if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1) + if (operandTy.hasRank() && + (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1)) return operand; return {}; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 0b92a3cb7a620..1298518e7b6e6 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1015,7 +1015,7 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents( static LogicalResult ReduceInferReturnTypes( ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl &inferredReturnShapes) { - if (!operandShape.hasRank()) { + if (!operandShape.hasRank() || operandShape.getRank() == 0) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 323864ea90130..5604a21ae4537 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -591,3 +591,15 @@ func.func @fold_abs_abs(%arg0: tensor) -> tensor { %1 = tosa.abs %0 : (tensor) -> tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: @fold_reduce_rank_zero +func.func nested @fold_reduce_rank_zero() { + // CHECK-NOT: tosa.reduce_min + // CHECK-NOT: tosa.reverse + %0 = tensor.empty() : tensor + %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor) -> tensor<1x10xi32> + %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor) -> tensor<1x10xi32> + return +}