From 69ebdcd4a57e0086e552855001bb2ca0e30a82f7 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 5 Oct 2023 15:19:04 +0100 Subject: [PATCH] [mlir][tosa] Add verifier for `ArgMax` operator Verifier ensures that operator is valid by checking: * Output type is of integer type * Axis is within the rank of the tensor Signed-off-by: Georgios Pinitas --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 ++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 15 +++++++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 6 +++--- mlir/test/Dialect/Tosa/constrained_shapes.mlir | 2 +- mlir/test/Dialect/Tosa/level_check.mlir | 6 +++--- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index f4d9a251fb978..a80111aedfe0b 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> { let results = (outs Tosa_Tensor: $output ); + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 0b92a3cb7a620..a719171b2b359 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -211,6 +211,21 @@ template static LogicalResult verifyConvOp(T op) { return success(); } +LogicalResult tosa::ArgMaxOp::verify() { + // Ensure output is of 32-bit integer + const auto resultETy = llvm::cast(getType()).getElementType(); + if (!resultETy.isIntOrIndex()) + return emitOpError("result tensor is not of integer type"); + + // Ensure axis is within the tensor rank + const auto inputType = llvm::cast(getInput().getType()); + const int64_t axis = getAxisAttr().getInt(); + if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank())) + return emitOpError("specified axis is outside the rank of the tensor"); + + return success(); +} + LogicalResult tosa::AvgPool2dOp::verify() { auto inputType = llvm::cast(getInput().getType()); if (hasZeroDimension(inputType)) diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 323864ea90130..d36cf6a1d94a9 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1,10 +1,10 @@ // RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s // CHECK-LABEL: @argmax_nofold -func.func @argmax_nofold(%arg0: tensor) -> tensor { +func.func @argmax_nofold(%arg0: tensor) -> tensor { // CHECK: tosa.argmax - %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor) -> tensor - return %0 : tensor + %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor) -> tensor + return %0 : tensor } // CHECK-LABEL: @add_bcast_zero_int diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir index 9acb024cf78d0..8c3ad828ab06f 100644 --- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir +++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir @@ -6,6 +6,6 @@ // Uses argmax as canonical example to validate constrained TOSA tensor shapes. // CHECK-LABEL: argmax func.func @test_argmax(%arg0: tensor) -> tensor { - %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor) -> tensor + %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index e7fdf8af409b5..68238087f5c25 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,10 +1,10 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate -func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> { +func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> - return %0 : tensor<1x1x1x1x29x4xf32> + %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> + return %0 : tensor<1x1x1x1x29x4xi32> } // -----