diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8183b40ad7346f..bca343cf877714 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -467,7 +467,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
       loc, destTensorType, reshapeOp->getResult(0),
       SmallVector<OpFoldResult>(destRank, zero),
-      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));
 
   // 7. Inject a copy to preserve DPS.
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index c71feddcc1c848..ad6c6a6f6199cc 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
   // CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
   // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
   // CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
-  // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>) 
+  // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
   // CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
   %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
     : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -397,3 +397,40 @@ transform.sequence failures(propagate) {
   transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
 }
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2, 4]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
+// CHECK-SAME: : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME: : tensor<32x32x784xf32> to tensor<32x?x?xf32>
+// CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<32x?x?xf32>)
+func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
+  %pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
+    : tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
+  return %pack : tensor<32x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+    : (!transform.any_op) -> !transform.op<"tensor.unpack">
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+}
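
Note on why the one-line fix works (a reviewer-style sketch, not part of the patch): tensor::getMixedSizes returns the mixed static/dynamic sizes of a tensor-typed value, keeping static extents as attributes and materializing a tensor.dim op for each dynamic extent. Computing the slice sizes from unPackOp->getResult(0) anchored those tensor.dim ops on the very op being lowered, which the rewriter replaces at the end of lowerUnPack; computing them from unPackOp.getDest() anchors them on the DPS destination operand, which has the same shape as the result and stays live after the unpack op is erased. For the tensor<32x?x?xf32> destination in the new test, the materialized IR looks roughly like this (value names %dest, %d1, %d2 are illustrative):

  %c1 = arith.constant 1 : index
  %d1 = tensor.dim %dest, %c1 : tensor<32x?x?xf32>
  %c2 = arith.constant 2 : index
  %d2 = tensor.dim %dest, %c2 : tensor<32x?x?xf32>
  // mixed sizes [32, %d1, %d2] become the sizes of the tensor.extract_slice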