-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Add verifier for ArgMax
operator
#68410
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa ChangesVerifier ensures that operator is valid by checking:
Full diff: https://github.com/llvm/llvm-project/pull/68410.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4d9a251fb97839..a80111aedfe0b59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let results = (outs
Tosa_Tensor: $output
);
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0b92a3cb7a6203d..af112aa65e2a371 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
return success();
}
+LogicalResult tosa::ArgMaxOp::verify() {
+ // Ensure output is of 32-bit integer
+ const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
+ if (!resultETy.isInteger(32))
+ return emitOpError("result tensor is not i32");
+
+ // Ensure axis is within the tensor rank
+ const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+ const int64_t axis = getAxisAttr().getInt();
+ if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+ return emitOpError("specified axis is outside the rank of the tensor");
+
+ return success();
+}
+
LogicalResult tosa::AvgPool2dOp::verify() {
auto inputType = llvm::cast<ShapedType>(getInput().getType());
if (hasZeroDimension(inputType))
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..d36cf6a1d94a9f3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
// CHECK: tosa.argmax
- %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
- return %0 : tensor<?x1xf32>
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
+ return %0 : tensor<?x1xi32>
}
// CHECK-LABEL: @add_bcast_zero_int
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 9acb024cf78d005..8c3ad828ab06f81 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -6,6 +6,6 @@
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
// CHECK-LABEL: argmax
func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
- %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
+ %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..68238087f5c2523 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
-func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
- return %0 : tensor<1x1x1x1x29x4xf32>
+ %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
+ return %0 : tensor<1x1x1x1x29x4xi32>
}
// -----
|
I think checking the axis in the verifier is a good idea. For the size check, I know that the current spec calls out i32 as the return type, but anticipating future changes where that is changed to a larger number, this feels like something that should go into the opt-in pass that takes a profile to validate against. That profile check would fail today, but implementations that don't strictly check can still use the operator set. |
Should we check that the return type is of an integer type? Not of a particular bitwidth per se. |
Verifier ensures that operator is valid by checking: * Output type is of integer type * Axis is within the rank of the tensor Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
a0b027a
to
69ebdcd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Int or Index seems to be the correct type check to verify for returns.
Verifier ensures that operator is valid by checking: