[mlir] Fix lower_unpack when dynamic dimensions are involved #68423
Conversation
When lowering `tensor.unpack`, we need to use the sizes of the destination tensor in the final `tensor.extract_slice` operation. Prior to this patch, when the destination tensor had dynamic dimensions, we would compute them from the result of the `tensor.unpack` operation instead of its destination argument.

This would produce invalid IR because the `tensor.dim` operations would need to appear before the `tensor.extract_slice` operation, but they would consume the final result of the lowering of `tensor.unpack`, which is defined after the `tensor.extract_slice` operation. In other words, the definition wouldn't dominate its uses. Concretely, we were generating:

```
%dynDim = tensor.dim %defLater, ... <-- %defLater defined below
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res)
```

Note: I checked the implementation of `lower_pack` and the code is correct as far as I can tell.
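With this patch, the dynamic sizes are taken from the unpack's destination operand instead, which is already defined when the slice is created. Schematically (SSA names illustrative, mirroring the snippet above):

```
%dynDim = tensor.dim %dest, ... <-- %dest is the unpack destination, defined earlier
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res, outs %dest)
```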
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Full diff: https://github.com/llvm/llvm-project/pull/68423.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8183b40ad7346f4..bca343cf8777149 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -467,7 +467,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, reshapeOp->getResult(0),
SmallVector<OpFoldResult>(destRank, zero),
- tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+ tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
SmallVector<OpFoldResult>(destRank, one));
// 7. Inject a copy to preserve DPS.
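For context (a sketch, not part of the diff): `tensor::getMixedSizes` returns a value's sizes as `OpFoldResult`s, folding static dimensions to constant attributes and materializing `tensor.dim` operations only for the dynamic ones. For a destination of type `tensor<32x?x?xf32>`, querying the sizes produces IR along these lines (`%dest` is an illustrative name):

```
%c1 = arith.constant 1 : index
%dim1 = tensor.dim %dest, %c1 : tensor<32x?x?xf32>
%c2 = arith.constant 2 : index
%dim2 = tensor.dim %dest, %c2 : tensor<32x?x?xf32>
// mixed sizes: [32, %dim1, %dim2]
```

Querying the sizes on `unPackOp.getDest()` is safe because the destination dominates the whole lowered sequence, whereas the old query on `unPackOp->getResult(0)` pinned the `tensor.dim` operations to a value that is only defined by the final `linalg.copy`.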
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index c71feddcc1c8486..ad6c6a6f6199cc6 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
// CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
- // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
+ // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
// CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -397,3 +397,40 @@ transform.sequence failures(propagate) {
transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
}
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2, 4]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
+// CHECK-SAME: : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME: : tensor<32x32x784xf32> to tensor<32x?x?xf32>
+// CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<32x?x?xf32>)
+func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
+ %pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
+ : tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
+ return %pack : tensor<32x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+}
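For reference, a sketch (not from the patch) of the invalid IR the pre-fix lowering would have produced for this test, with the `tensor.dim` operands referring to the copy result defined later (`%collapsed` and `%copy` are illustrative names):

```
%c1 = arith.constant 1 : index
%dim1 = tensor.dim %copy, %c1 : tensor<32x?x?xf32>  // use before def
%c2 = arith.constant 2 : index
%dim2 = tensor.dim %copy, %c2 : tensor<32x?x?xf32>  // use before def
%slice = tensor.extract_slice %collapsed[0, 0, 0] [32, %dim1, %dim2] [1, 1, 1]
    : tensor<32x32x784xf32> to tensor<32x?x?xf32>
%copy = linalg.copy ins(%slice : tensor<32x?x?xf32>)
                    outs(%arg1 : tensor<32x?x?xf32>) -> tensor<32x?x?xf32>
```

This fails verification because `%copy` does not dominate its uses in the `tensor.dim` operations; with the fix, the dims are read from `%arg1` instead, exactly as the CHECK lines above expect.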