Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tosa] Add verifier for ArgMax operator #68410

Merged
merged 1 commit into from
Oct 9, 2023

Conversation

GeorgeARM
Copy link
Contributor

Verifier ensures that operator is valid by checking:

  • Output type is I32
  • Axis is within the rank of the tensor

@llvmbot
Copy link
Member

llvmbot commented Oct 6, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Changes

Verifier ensures that operator is valid by checking:

  • Output type is I32
  • Axis is within the rank of the tensor

Full diff: https://github.com/llvm/llvm-project/pull/68410.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+15)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/constrained_shapes.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4d9a251fb97839..a80111aedfe0b59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   let results = (outs
     Tosa_Tensor: $output
   );
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0b92a3cb7a6203d..af112aa65e2a371 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
   return success();
 }
 
+LogicalResult tosa::ArgMaxOp::verify() {
+  // Ensure output is of 32-bit integer
+  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
+  if (!resultETy.isInteger(32))
+    return emitOpError("result tensor is not i32");
+
+  // Ensure axis is within the tensor rank
+  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  const int64_t axis = getAxisAttr().getInt();
+  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+    return emitOpError("specified axis is outside the rank of the tensor");
+
+  return success();
+}
+
 LogicalResult tosa::AvgPool2dOp::verify() {
   auto inputType = llvm::cast<ShapedType>(getInput().getType());
   if (hasZeroDimension(inputType))
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..d36cf6a1d94a9f3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   // CHECK: tosa.argmax
-  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
-  return %0 : tensor<?x1xf32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
+  return %0 : tensor<?x1xi32>
 }
 
 // CHECK-LABEL: @add_bcast_zero_int
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 9acb024cf78d005..8c3ad828ab06f81 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -6,6 +6,6 @@
 // Uses argmax as canonical example to validate constrained TOSA tensor shapes.
 // CHECK-LABEL: argmax
 func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
-  %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..68238087f5c2523 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
 
 
-func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
-  return %0 : tensor<1x1x1x1x29x4xf32>
+  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
+  return %0 : tensor<1x1x1x1x29x4xi32>
 }
 
 // -----

@eric-k256 eric-k256 self-requested a review October 6, 2023 23:45
@eric-k256
Copy link
Contributor

I think checking the axis in the verifier is a good idea. For the size check, I know that the current spec calls out i32 as the return type, but anticipating future changes where that is changed to a larger number, this feels like something that should go into the opt-in pass that takes a profile to validate against. That profile check would fail today, but implementations that don't strictly check can still use the operator set.

@GeorgeARM
Copy link
Contributor Author

I think checking the axis in the verifier is a good idea. For the size check, I know that the current spec calls out i32 as the return type, but anticipating future changes where that is changed to a larger number, this feels like something that should go into the opt-in pass that takes a profile to validate against. That profile check would fail today, but implementations that don't strictly check can still use the operator set.

Should we check that the return type is of an integer type? Not of a particular bitwidth per se.

Verifier ensures that operator is valid by checking:

* Output type is of integer type
* Axis is within the rank of the tensor

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Int or Index seems to be the correct type check to verify for returns.

@eric-k256 eric-k256 merged commit 414709e into llvm:main Oct 9, 2023
2 checks passed
@GeorgeARM GeorgeARM deleted the verify_argmax branch October 9, 2023 15:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants