[mlir][memref] Add memref alias folding for masked transfers #71476
Conversation
Because masking of vector.transfer ops semantically applies to the unpermuted input vector (for reads) and the permuted output vector (for writes), masks apply independently of any subviews.
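As an illustration (a minimal sketch with invented shapes and value names such as %src, %pad, and %mask, not taken from the patch), the mask guards lanes of the transferred vector, so it can be forwarded unchanged when the read through the subview is rebased onto the original memref:

```mlir
// Before: masked 1-D read through a subview (hypothetical shapes);
// %mask : vector<8xi1>, %pad : f32.
%sv = memref.subview %src[%i, %j] [4, 8] [1, 1]
    : memref<16x32xf32> to memref<4x8xf32, strided<[32, 1], offset: ?>>
%v = vector.transfer_read %sv[%k, %l], %pad, %mask {in_bounds = [true]}
    : memref<4x8xf32, strided<[32, 1], offset: ?>>, vector<8xf32>

// After folding: the same mask, with indices rebased onto %src.
%idx0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %k]
%idx1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%j, %l]
%v = vector.transfer_read %src[%idx0, %idx1], %pad, %mask {in_bounds = [true]}
    : memref<16x32xf32>, vector<8xf32>
```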
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Changes: Because masking of vector.transfer ops applies to the unpermuted input vector (for reads) and the permuted output vector (for writes), masks apply independently of any subviews and can thus be forwarded to the folded transfers.

Full diff: https://github.com/llvm/llvm-project/pull/71476.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 043e8fbcdd2f6fb..b78c4510ff88585 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
"must be a vector transfer op");
if (xferOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
- if (xferOp.getMask())
- return rewriter.notifyMatchFailure(xferOp, "masked transfer");
if (!subviewOp.hasUnitStride()) {
return rewriter.notifyMatchFailure(
xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
- op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
+ op.getPadding(), op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
- op.getInBoundsAttr());
+ op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f11e22749bb16d..2e1319420fb3eaf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -266,6 +266,63 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
// -----
+func.func @fold_masked_vector_transfer_read_with_subview(
+ %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+ %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+ %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
+ : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
+ : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
+ return %1 : vector<4xf32>
+}
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_masked_vector_transfer_read_with_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[ARG7]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_subview(
+ %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+ %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+ : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
+ : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_masked_vector_transfer_write_with_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
// Test with affine.load/store ops. We only do a basic test here since the
// logic is identical to that with memref.load/store ops. The same affine.apply
// ops would be generated.
A tricky part with subview ops in general is rank-reducing behavior. Thinking about it a bit, it should actually be fine: folding a rank-reducing subview into the original memref just means extra dims show up in the transfer op's permutation map. Those dims get compressed away by inferTransferOpMaskType, so we can still use the same mask.
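For concreteness, a rank-reducing case might look roughly like the hypothetical sketch below (shapes and value names invented, and the folded form only approximated): the dropped unit dim reappears as an extra dim in the permutation map of the folded transfer, but the mask type inferred from that map is still vector<4xi1>, so the original mask remains usable.

```mlir
// Rank-reducing subview: the leading unit dim is dropped from the result type.
%sv = memref.subview %src[%i, %j] [1, 4] [1, 1]
    : memref<8x16xf32> to memref<4xf32, strided<[1], offset: ?>>
%v = vector.transfer_read %sv[%k], %pad, %mask {in_bounds = [true]}
    : memref<4xf32, strided<[1], offset: ?>>, vector<4xf32>

// Roughly what folding would produce: the permutation map gains a dim for the
// dropped source dimension, yet the mask it implies is still vector<4xi1>.
%idx1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%j, %k]
%v = vector.transfer_read %src[%i, %idx1], %pad, %mask
    {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d1)>}
    : memref<8x16xf32>, vector<4xf32>
```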
Let's add some rank-reducing tests, please.
The contents of a mask on a masked transfer are unaffected by the particular region of memory being read from or stored to, so just forward the mask in the subview folding patterns.