Skip to content

Commit

Permalink
Updates to 'tosa.reshape' verifier (#87416)
Browse files Browse the repository at this point in the history
This addition catches common cases of malformed `tosa.reshape` ops. This
prevents the `--tosa-to-tensor` pass from asserting when fed invalid
operations, as these will be caught ahead of time by the verifier.

Closes #87396
  • Loading branch information
rafaelubalmw authored Apr 3, 2024
1 parent cd29126 commit fbcd0c6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
17 changes: 13 additions & 4 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -955,25 +955,34 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}

mlir::LogicalResult tosa::ReshapeOp::verify() {
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType outputType = llvm::cast<ShapedType>(getType());
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();

if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
return emitOpError() << "tensor has a dimension with size zero. Each "
"dimension of a tensor must have size >= 1";

if ((int64_t) getNewShape().size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";

for (auto [newShapeDim, outputShapeDim] :
zip(getNewShape(), outputType.getShape()))
if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
newShapeDim != outputShapeDim)
return emitOpError() << "new shape is inconsistent with result shape";

if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
int64_t inputElementsNum = inputType.getNumElements();
int64_t outputElementsNum = outputType.getNumElements();
if (inputElementsNum != outputElementsNum) {
return emitOpError() << "Cannot reshape " << inputElementsNum
return emitOpError() << "cannot reshape " << inputElementsNum
<< " elements into " << outputElementsNum;
}
}

int missingDims = llvm::count(getNewShape(), -1);
if (missingDims > 1)
return emitOpError() << "At most one target dimension can be -1";
return emitOpError() << "expected at most one target dimension to be -1";

return mlir::success();
}
Expand Down
58 changes: 45 additions & 13 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -243,38 +243,70 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {

// -----

func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
%0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
return
}

// -----

func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
// expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
return %0 : tensor<100x100xf32>
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
return
}

// -----

func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
func.func @test_reshape_rank_mismatch(%arg0 : tensor<?xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op new shape does not match result rank}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4>} : (tensor<?xf32>) -> tensor<?xf32>
return
}

// -----

func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
func.func @test_reshape_inconsistent_result_type(%arg0 : tensor<?xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op new shape is inconsistent with result shape}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4, -1>} : (tensor<?xf32>) -> tensor<?x3x5xf32>
return
}

// -----

func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 15}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 3, 5>} : (tensor<2x4xf32>) -> tensor<3x5xf32>
return
}

// -----

func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>
return
}

// -----

func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
%0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
return
}

// -----

func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
// expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
return %0 : tensor<100x100xf32>
}

// -----

func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
Expand Down

0 comments on commit fbcd0c6

Please sign in to comment.