-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add leading unit dim folding patterns for masked transfers #71466
[mlir][vector] Add leading unit dim folding patterns for masked transfers #71466
Conversation
…fers This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` if the mask size for the unit dim is non-one, in which case all subsequent sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can just be dropped directly.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Quinn Dawkins (qedawkins) Changes: This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. Full diff: https://github.com/llvm/llvm-project/pull/71466.diff — 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6bbb293fa2a6b5c..75f32b23e57b0d6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include <numeric>
+
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
if (read.getTransferRank() == 0)
return failure();
- if (read.getMask())
- return failure();
-
auto shapedType = cast<ShapedType>(read.getSource().getType());
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim
inBoundsAttr = rewriter.getArrayAttr(
read.getInBoundsAttr().getValue().take_back(newType.getRank()));
+ Value mask = Value();
+ if (read.getMask()) {
+ // The mask shape must always match the shape of the written vector, so we
+ // can safely use the same extraction indices.
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+ mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
+ splatZero(dropDim));
+ }
+
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), newType, read.getSource(), read.getIndices(),
- AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
- inBoundsAttr);
+ AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
return success();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
if (write.getTransferRank() == 0)
return failure();
- if (write.getMask())
- return failure();
-
auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim
auto newVector = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.getVector(), splatZero(dropDim));
+
+ if (write.getMask()) {
+ // The mask shape must always match the shape of the written vector, so we
+ // can safely use the same extraction indices.
+ auto newMask = rewriter.create<vector::ExtractOp>(
+ write.getLoc(), write.getMask(), splatZero(dropDim));
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ write, newVector, write.getSource(), write.getIndices(),
+ AffineMapAttr::get(newMap), newMask, inBoundsAttr);
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), inBoundsAttr);
-
return success();
}
};
@@ -467,6 +482,40 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
}
};
+// Drops leading 1 dimensions from vector.constant_mask and inserts a
+// vector.broadcast back to the original shape.
+struct CastAwayConstantMaskLeadingOneDim
+ : public OpRewritePattern<vector::ConstantMaskOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
+ PatternRewriter &rewriter) const override {
+ VectorType oldType = mask.getType();
+ VectorType newType = trimLeadingOneDims(oldType);
+
+ if (newType == oldType)
+ return failure();
+
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+ SmallVector<int64_t> dimSizes;
+ for (auto attr : mask.getMaskDimSizes())
+ dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+
+ // If any of the dropped unit dims has a size of `0`, the entire mask is a
+ // zero mask, else the unit dim has no effect on the mask.
+ int64_t flatLeadingSize =
+ std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
+ static_cast<int64_t>(1), std::multiplies<int64_t>());
+ SmallVector<int64_t> newDimSizes({flatLeadingSize});
+ newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
+
+ auto newMask = rewriter.create<vector::ConstantMaskOp>(
+ mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
@@ -474,7 +523,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns
.add<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
- CastAwayTransferReadLeadingOneDim,
+ CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
populateShapeCastFoldingPatterns(patterns, benefit);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index e5b27b04dcc8096..5de30206927db2f 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
return %0: vector<1x4xf16>
}
+// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
+func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+ %f0 = arith.constant 0. : f16
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+ // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+ // CHECK: return %[[CAST]]
+ return %0: vector<1x4xf16>
+}
+
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = arith.constant 0 : index
@@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
return
}
+// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
+func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+ // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+ return
+}
+
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = arith.constant 0 : index
@@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
return %0: vector<1x1x8x1x[8]xi1>
}
+
+// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
+// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1>
+func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+ %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
+ return %0: vector<1x1x8x2x1xi1>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks!
This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` if the mask size for the unit dim is non-one, in which case all subsequent sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can just be dropped directly.