Skip to content

Commit afb0582

Browse files
committed
Fix TOSA verifier to emit verbose errors
Also add a test for invalid ops, which was previously missing.
1 parent b6ccca2 commit afb0582

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,14 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
527527
auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
528528

529529
// Must be ranked tensor types
530-
if (!inputType || !weightType)
530+
if (!inputType) {
531+
op.emitOpError("expect a ranked tensor for input, got ") << op.input();
531532
return failure();
533+
}
534+
if (!weightType) {
535+
op.emitOpError("expect a ranked tensor for weight, got ") << op.weight();
536+
return failure();
537+
}
532538

533539
auto inputEType = inputType.getElementType();
534540
auto weightEType = weightType.getElementType();
@@ -537,14 +543,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
537543
bool weightIsQuant = !weightEType.template isa<FloatType>();
538544

539545
// Either both must be quantized or both unquantized.
540-
if (inputIsQuant != weightIsQuant)
546+
if (inputIsQuant != weightIsQuant) {
547+
op.emitOpError(
548+
"expect both input and weight to be float or not together, got ")
549+
<< inputEType << " and " << weightEType;
541550
return failure();
551+
}
542552

543553
// Quantized type must have constructed the quantizationattr, and unquantized
544554
// types should not have a quantizationattr.
545555
if ((inputIsQuant && !op.quantization_info()) ||
546-
(!inputIsQuant && op.quantization_info()))
556+
(!inputIsQuant && op.quantization_info())) {
557+
op.emitOpError("quantizationattr is required for quantized type, and not "
558+
"allowed for float type");
547559
return failure();
560+
}
548561

549562
return success();
550563
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
4+
// Conv2d where the input element type is float but the weight element type is
// an integer: the verifier requires input and weight to be both float or both
// non-float, so this op must be rejected.
func @test_conv2d(%input: tensor<1x29x29x4xf32>, %weight: tensor<16x3x3x4xi8>, %bias: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
  // expected-error@+1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}}
  %conv = "tosa.conv2d"(%input, %weight, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
  return %conv : tensor<1x27x27x16xi8>
}
10+
11+
// -----
12+
13+
// Conv2d whose input is an unranked tensor: the verifier demands a ranked
// tensor for the input operand and should name the offending value in the
// diagnostic.
func @test_conv2d(%input: tensor<*xi8>, %weight: tensor<16x3x3x4xi8>, %bias: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
  // expected-error@+1 {{expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
  %conv = "tosa.conv2d"(%input, %weight, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
  return %conv : tensor<1x27x27x16xi8>
}
19+
20+
// -----
21+
22+
// Conv2d whose weight is an unranked tensor: the verifier demands a ranked
// tensor for the weight operand and should name the offending value in the
// diagnostic.
func @test_conv2d(%input: tensor<1x29x29x4xi8>, %weight: tensor<*xi8>, %bias: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
  // expected-error@+1 {{expect a ranked tensor for weight, got <block argument> of type 'tensor<*xi8>' at index: 1}}
  %conv = "tosa.conv2d"(%input, %weight, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
  return %conv : tensor<1x27x27x16xi8>
}
28+
29+
30+
// -----
31+
32+
// Conv2d with quantized (integer) element types but no quantization_info
// attribute: the verifier requires the attribute for quantized types (and
// forbids it for float types), so this op must be rejected.
func @test_conv2d(%input: tensor<1x29x29x4xi8>, %weight: tensor<16x3x3x4xi8>, %bias: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
  // expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}}
  %conv = "tosa.conv2d"(%input, %weight, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
  return %conv : tensor<1x27x27x16xi8>
}
38+
39+

0 commit comments

Comments
 (0)