diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 043e8fbcdd2f6..b78c4510ff885 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, "must be a vector transfer op"); if (xferOp.hasOutOfBoundsDim()) return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); - if (xferOp.getMask()) - return rewriter.notifyMatchFailure(xferOp, "masked transfer"); if (!subviewOp.hasUnitStride()) { return rewriter.notifyMatchFailure( xferOp, "non-1 stride subview, need to track strides in folded memref"); @@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder::matchAndRewrite( AffineMapAttr::get(expandDimsToRank( op.getPermutationMap(), subViewOp.getSourceType().getRank(), subViewOp.getDroppedDims())), - op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr()); + op.getPadding(), op.getMask(), op.getInBoundsAttr()); }) .Case([&](gpu::SubgroupMmaLoadMatrixOp op) { rewriter.replaceOpWithNewOp( @@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder::matchAndRewrite( AffineMapAttr::get(expandDimsToRank( op.getPermutationMap(), subViewOp.getSourceType().getRank(), subViewOp.getDroppedDims())), - op.getInBoundsAttr()); + op.getMask(), op.getInBoundsAttr()); }) .Case([&](gpu::SubgroupMmaStoreMatrixOp op) { rewriter.replaceOpWithNewOp( diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 3f11e22749bb1..8fe87bc8c57c3 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -266,6 +266,127 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview( // ----- +func.func @fold_masked_vector_transfer_read_with_subview( + %arg0 : memref>, + %arg1: index, %arg2 : index, %arg3 
: index, %arg4: index, %arg5 : index, + %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> { + %cst = arith.constant 0.0 : f32 + %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1] + : memref> to + memref> + %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]} + : memref>, vector<4xf32> + return %1 : vector<4xf32> +} +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: func @fold_masked_vector_transfer_read_with_subview +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1> +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]] +// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[MASK]] {{.*}} : memref>, + %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index, + %arg6 : index, %mask : vector<4x3xi1>) -> vector<3x4xf32> { + %cst = arith.constant 0.0 : f32 + %0 = memref.subview %arg0[0, %arg1, 0, %arg2] [1, %arg3, 1, %arg4] [1, 1, 1, 1] + : memref> to + memref> + %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask { + permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} + : memref>, vector<3x4xf32> + return %1 : vector<3x4xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)> +// CHECK: func @fold_masked_vector_transfer_read_with_rank_reducing_subview +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index 
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG5]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]] +// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[PAD]], %[[MASK]] {{.*}} permutation_map = #[[MAP1]]} : memref>, + %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, + %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) { + %cst = arith.constant 0.0 : f32 + %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] + : memref> to + memref> + vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]} + : vector<4xf32>, memref> + return +} +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: func @fold_masked_vector_transfer_write_with_subview +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1> +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]] +// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true]} : vector<4xf32>, memref>, + %arg1 : vector<3x4xf32>, %arg2: index, %arg3 : index, %arg4 : index, + %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4x3xi1>) { + %cst = 
arith.constant 0.0 : f32 + %0 = memref.subview %arg0[0, %arg2, 0, %arg3] [1, %arg4, 1, %arg5] [1, 1, 1, 1] + : memref> to + memref> + vector.transfer_write %arg1, %0[%arg6, %arg7], %mask { + permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} + : vector<3x4xf32>, memref> + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)> +// CHECK: func @fold_masked_vector_transfer_write_with_rank_reducing_subview +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<3x4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]]] +// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[MAP1]]} : vector<3x4xf32>, memref