diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 78201ae29cd9b..c9a85919ec799 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2356,12 +2356,11 @@ LogicalResult ExpandShapeOp::verify() { // Verify if provided output shapes are in agreement with output type. DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr(); ArrayRef resShape = getResult().getType().getShape(); - unsigned staticShapeNum = 0; - - for (auto [pos, shape] : llvm::enumerate(resShape)) - if (!ShapedType::isDynamic(shape) && - shape != staticOutputShapes[staticShapeNum++]) - emitOpError("invalid output shape provided at pos ") << pos; + for (auto [pos, shape] : llvm::enumerate(resShape)) { + if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) { + return emitOpError("invalid output shape provided at pos ") << pos; + } + } return success(); } diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 99b5f78b03fba..e49dff44ae0d6 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -502,7 +502,7 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) { %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref> - %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [1, 16, %sz0, 1] : memref> into memref> + %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref> into memref> %dim = memref.dim %expand_shape, %c0 : memref> affine.for %arg6 = 0 to %dim step 64 { diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 60fb0ffeee240..b60894377f22f 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -203,7 +203,8 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, %arg3: memref>, %arg4: index, %arg5: index, - %arg6: index) { + %arg6: index, + %arg7: memref<4x?x4xf32>) { // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : @@ -248,6 +249,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, // CHECK-SAME: memref> into memref %r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] : memref> into memref + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] + %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2] + : memref<4x?x4xf32> into memref<2x2x?x2x2xf32> return }