[mlir][vector] Add unroll patterns for vector.load and vector.store #143420
Conversation
@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes: This PR adds unroll patterns for vector.load and vector.store with rank > 1, unrolling them to 1D loads and stores. This PR is a follow-up to #137558.

Full diff: https://github.com/llvm/llvm-project/pull/143420.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fc443ab0d138e..e912a6ef29b21 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,6 +54,33 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
return slicedIndices;
}
+// compute the new indices for vector.load/store by adding offsets to
+// originalIndices.
+// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
+// Last m of originalIndices will be updated.
+static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
+ Location loc,
+ ArrayRef<Value> originalIndices,
+ ArrayRef<int64_t> offsets) {
+ assert(offsets.size() <= originalIndices.size() &&
+ "Offsets should not exceed the number of original indices");
+ SmallVector<Value> indices(originalIndices);
+ auto originalIter = originalIndices.rbegin();
+ auto offsetsIter = offsets.rbegin();
+ auto indicesIter = indices.rbegin();
+ while (offsetsIter != offsets.rend()) {
+ Value original = *originalIter;
+ int64_t offset = *offsetsIter;
+ if (offset != 0)
+ *indicesIter = rewriter.create<arith::AddIOp>(
+ loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ originalIter++;
+ offsetsIter++;
+ indicesIter++;
+ }
+ return indices;
+}
+
// Clones `op` into a new operations that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
+ UnrollLoadPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = loadOp.getVectorType();
+ // Only unroll >1D loads
+ if (vecType.getRank() <= 1)
+ return failure();
+
+ Location loc = loadOp.getLoc();
+ ArrayRef<int64_t> originalShape = vecType.getShape();
+
+ // Target type is a 1D vector of the innermost dimension.
+ auto targetType =
+ VectorType::get(originalShape.back(), vecType.getElementType());
+
+ // Extend the targetShape to the same rank of original shape by padding 1s
+ // for leading dimensions for convenience of computing offsets
+ SmallVector<int64_t> targetShape(originalShape.size(), 1);
+ targetShape.back() = originalShape.back();
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, vecType, rewriter.getZeroAttr(vecType));
+
+ SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
+ loadOp.getIndices().end());
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, targetShape)) {
+ SmallVector<Value> indices =
+ computeIndices(rewriter, loc, originalIndices, offsets);
+ Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
+ loadOp.getBase(), indices);
+ // Insert the slice into the result at the correct position.
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, slice, result, offsets, SmallVector<int64_t>({1}));
+ }
+ rewriter.replaceOp(loadOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
+struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
+ UnrollStorePattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = storeOp.getVectorType();
+ // Only unroll >1D stores.
+ if (vecType.getRank() <= 1)
+ return failure();
+
+ Location loc = storeOp.getLoc();
+ ArrayRef<int64_t> originalShape = vecType.getShape();
+
+ // Extend the targetShape to the same rank of original shape by padding 1s
+ // for leading dimensions for convenience of computing offsets
+ SmallVector<int64_t> targetShape(originalShape.size(), 1);
+ targetShape.back() = originalShape.back();
+
+ Value base = storeOp.getBase();
+ Value vector = storeOp.getValueToStore();
+
+ SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
+ storeOp.getIndices().end());
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, targetShape)) {
+ SmallVector<Value> indices =
+ computeIndices(rewriter, loc, originalIndices, offsets);
+ offsets.pop_back();
+ Value slice = rewriter.create<vector::ExtractOp>(loc, vector, offsets);
+ rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+ }
+ rewriter.eraseOp(storeOp);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
UnrollBroadcastPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -699,10 +818,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
- patterns
- .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
- UnrollContractionPattern, UnrollElementwisePattern,
- UnrollReductionPattern, UnrollMultiReductionPattern,
- UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
- patterns.getContext(), options, benefit);
+ patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+ UnrollContractionPattern, UnrollElementwisePattern,
+ UnrollReductionPattern, UnrollMultiReductionPattern,
+ UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
+ UnrollStorePattern, UnrollBroadcastPattern>(
+ patterns.getContext(), options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
new file mode 100644
index 0000000000000..3135268b8d61b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @unroll_2D_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: return %[[V7]] : vector<4x4xf16>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return
+}
+
+// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: return %[[V3]] : vector<2x2xf16>
+ %c1 = arith.constant 1 : index
+ %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return %0 : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
+func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ %c1 = arith.constant 1 : index
+ vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 54aa96ba89a00..8014362a1a6ec 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -292,6 +292,44 @@ struct TestVectorTransferUnrollingPatterns
llvm::cl::init(false)};
};
+struct TestVectorLoadStoreUnrollPatterns
+ : public PassWrapper<TestVectorLoadStoreUnrollPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorLoadStoreUnrollPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-load-store-unroll";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for vector.load and vector.store ops";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect, arith::ArithDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+
+ // Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors
+ vector::UnrollVectorOptions options;
+ options.setFilterConstraint([](Operation *op) {
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return success(loadOp.getType().getRank() > 1);
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return success(storeOp.getVectorType().getRank() > 1);
+ return failure();
+ });
+
+ vector::populateVectorUnrollPatterns(patterns, options);
+
+ // Apply the patterns
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestScalarVectorTransferLoweringPatterns
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
@@ -1032,6 +1070,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferUnrollingPatterns>();
+ PassRegistration<TestVectorLoadStoreUnrollPatterns>();
+
PassRegistration<TestScalarVectorTransferLoweringPatterns>();
PassRegistration<TestVectorTransferOpt>();
@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Changes: This PR adds unroll patterns for vector.load and vector.store with rank > 1, unrolling them to 1D loads and stores. This PR is a follow-up to #137558.

Full diff: https://github.com/llvm/llvm-project/pull/143420.diff (3 files affected; same diff as shown above)
LG % some minor asks, thanks!
Thanks @banach-space for the feedback; I've addressed it. Please take another look whenever you have some time.
Hey, thanks for the update! On second scan, it feels that computeIndices could be simplified. Please take a look. Thank you!
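For illustration, here is a minimal sketch of one way computeIndices could be simplified, indexing the trailing positions directly instead of walking three parallel reverse iterators. This is a hedged reading of the reviewer's suggestion, not the code that ultimately landed:

static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
                                         Location loc,
                                         ArrayRef<Value> originalIndices,
                                         ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices);
  // Offsets apply to the trailing dimensions; leading indices stay unchanged.
  size_t rankDiff = originalIndices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    if (offset == 0)
      continue;
    indices[rankDiff + i] = rewriter.create<arith::AddIOp>(
        loc, originalIndices[rankDiff + i],
        rewriter.create<arith::ConstantIndexOp>(loc, offset));
  }
  return indices;
}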
Thanks for moving this from linearization to unrolling, i.e. from #136193 and #137558. As far as I can tell, the existing 9 unroll patterns support unrolling to any tile size of the vector; see, for example, the elementwise and transfer_read patterns, and the tests in the file you added your tests to in this PR. Can you motivate deviating from this design? For example, based on the other unroll patterns I would expect providing a 'target shape' of 2x2 to convert a load of 6x6 into 9 loads of 2x2. I think that transfer_read is a more complex op than vector.load; transfer_read, for example, is lowered to load. In terms of achieving what I think is your goal of only having rank-1 loads, I propose the following: (1) unroll the load/store to the target shape with the usual machinery, then (2) linearize each resulting small load/store to a rank-1 op.
In my opinion, having a linearization pattern for (2) above is fine, because it converts one rank-2 op into one rank-1 op. That's what I was kind of driving at when I added this comment: llvm-project/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h, line 411 at a53003f.
I acknowledge that the end result is exactly what you had in #136193! But I think it makes Vector more modular and intuitive for other users if unrolling and linearization have a clear boundary.
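To make the proposed direction concrete, here is a hedged sketch of how the generic unrolling machinery could be pointed at loads and stores through the existing UnrollVectorOptions hooks. setNativeShapeFn is a real hook in VectorRewritePatterns.h, but this particular configuration is an assumption for illustration, not code from this PR:

// Hypothetical configuration: report a native shape of 1x...x1xN per op, so
// generic unrolling emits loads/stores of only the innermost dimension.
vector::UnrollVectorOptions options;
options.setNativeShapeFn(
    [](Operation *op) -> std::optional<SmallVector<int64_t>> {
      VectorType vecType;
      if (auto loadOp = dyn_cast<vector::LoadOp>(op))
        vecType = loadOp.getVectorType();
      else if (auto storeOp = dyn_cast<vector::StoreOp>(op))
        vecType = storeOp.getVectorType();
      else
        return std::nullopt;
      if (vecType.getRank() <= 1)
        return std::nullopt;
      // Keep the innermost dimension whole; unroll all leading dimensions.
      SmallVector<int64_t> nativeShape(vecType.getRank(), 1);
      nativeShape.back() = vecType.getShape().back();
      return nativeShape;
    });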
@@ -178,6 +178,16 @@ struct TestVectorUnrollingPatterns
        return success(isa<vector::TransposeOp>(op));
      }));

  populateVectorUnrollPatterns(
      patterns, UnrollVectorOptions()
                    .setNativeShape(ArrayRef<int64_t>{2, 2})
I think {2, 2} should be used.
The motivation for this unrolling was to have 1D loads and stores, and hence the target shape is set accordingly. But I agree with your point about making it more modular to align with the vector dialect's flexible unrolling design. Can I address this in a follow-up PR?
The two designs seem orthogonal, so IMO it doesn't make sense to do it as a follow-up.
Hi @newling, I changed it to follow the same design as the other patterns. Can you take a look?
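Assuming the revision mirrors the other unroll tests, the registration would plausibly complete the snippet quoted above along these lines. This is a sketch under that assumption; the filter constraint shown is illustrative, not the final committed code:

populateVectorUnrollPatterns(
    patterns, UnrollVectorOptions()
                  .setNativeShape(ArrayRef<int64_t>{2, 2})
                  .setFilterConstraint([](Operation *op) {
                    // Restrict unrolling to vector.load/vector.store ops.
                    return success(isa<vector::LoadOp, vector::StoreOp>(op));
                  }));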
✅ With the latest revision this PR passed the C/C++ code formatter.
Thanks @nbpatel, this looks good to me! My remaining comments are all quite minor.
Thanks, I addressed your comments. Can you please approve if it looks OK?
LGTM
…lvm#143420) This PR adds unroll patterns for vector.load and vector.store. This PR is a follow-up of llvm#137558
This PR adds unroll patterns for vector.load and vector.store with rank > 1. This PR is a follow-up of #137558.