diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 393f73dc65cd8..78201ae29cd9b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2353,6 +2353,16 @@ LogicalResult ExpandShapeOp::verify() { << " dynamic dims while output_shape has " << getOutputShape().size() << " values"; + // Verify if provided output shapes are in agreement with output type. + DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr(); + ArrayRef resShape = getResult().getType().getShape(); + unsigned staticShapeNum = 0; + + for (auto [pos, shape] : llvm::enumerate(resShape)) + if (!ShapedType::isDynamic(shape) && + shape != staticOutputShapes[staticShapeNum++]) + emitOpError("invalid output shape provided at pos ") << pos; + return success(); } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 70c96aad9555e..0f533cb95a0ca 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1103,3 +1103,14 @@ func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) : memref<7x22x333x4444xi32> to memref<7x11x4444xi32> return } + +// ----- + +func.func @expand_shape_invalid_output_shape( + %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) { + // expected-error @+1 {{invalid output shape provided at pos 2}} + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 21] : + memref<30x20xf32, strided<[4000, 2], offset: 100>> + into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>> + return +}