diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index e6ba6e6bc602d..8ad8e41414656 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2324,7 +2324,10 @@ struct RFFT2dConverter final : public OpRewritePattern { auto loc = rfft2d.getLoc(); auto input = rfft2d.getInput(); auto elementType = - cast(cast(input.getType()).getElementType()); + dyn_cast(cast(input.getType()).getElementType()); + if (!elementType) + return rewriter.notifyMatchFailure(rfft2d, + "only supports float element types"); // Compute the output type and set of dynamic sizes llvm::SmallVector dynamicSizes; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir index ad65410e635e9..b78577275a52a 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -27,3 +27,12 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, % %2 = tosa.reshape %0 {new_shape = array} : (tensor<*xf32>) -> tensor<10x10xf32> return %2 : tensor<10x10xf32> } + +// ----- + +// CHECK-LABEL: @rfft2d_with_non_float_type +func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) { + // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}} + %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) + return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32> +}