[mlir][vector] Add unroll patterns for vector.load and vector.store #143420

Merged
merged 8 commits into llvm:main on Jun 20, 2025

Conversation

@nbpatel (Contributor) commented Jun 9, 2025

This PR adds unroll patterns for vector.load and vector.store with rank > 1. It is a follow-up to #137558.
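
For illustration, here is a minimal before/after sketch of the kind of rewrite these patterns perform (a 2x4 variant of the 4x4 test added in this PR; function and value names are illustrative):

// Before: a single rank-2 load.
func.func @before(%mem: memref<2x4xf16>) -> vector<2x4xf16> {
  %c0 = arith.constant 0 : index
  %0 = vector.load %mem[%c0, %c0] : memref<2x4xf16>, vector<2x4xf16>
  return %0 : vector<2x4xf16>
}

// After: one 1-D load per row, composed with vector.insert_strided_slice.
func.func @after(%mem: memref<2x4xf16>) -> vector<2x4xf16> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant dense<0.000000e+00> : vector<2x4xf16>
  %row0 = vector.load %mem[%c0, %c0] : memref<2x4xf16>, vector<4xf16>
  %acc0 = vector.insert_strided_slice %row0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<2x4xf16>
  %row1 = vector.load %mem[%c1, %c0] : memref<2x4xf16>, vector<4xf16>
  %acc1 = vector.insert_strided_slice %row1, %acc0 {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<2x4xf16>
  return %acc1 : vector<2x4xf16>
}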

@llvmbot (Member) commented Jun 9, 2025

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds unroll patterns for vector.load and vector.store with rank > 1, unrolling them to 1-D loads and stores. It is a follow-up to #137558.


Full diff: https://github.com/llvm/llvm-project/pull/143420.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+125-6)
  • (added) mlir/test/Dialect/Vector/vector-load-store-unroll.mlir (+73)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+40)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fc443ab0d138e..e912a6ef29b21 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,6 +54,33 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
   return slicedIndices;
 }
 
+// Compute the new indices for vector.load/store by adding offsets to
+// originalIndices.
+// It assumes m <= n, where m = offsets.size() and n = originalIndices.size();
+// only the last m entries of originalIndices are updated.
+static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
+                                         Location loc,
+                                         ArrayRef<Value> originalIndices,
+                                         ArrayRef<int64_t> offsets) {
+  assert(offsets.size() <= originalIndices.size() &&
+         "Offsets should not exceed the number of original indices");
+  SmallVector<Value> indices(originalIndices);
+  auto originalIter = originalIndices.rbegin();
+  auto offsetsIter = offsets.rbegin();
+  auto indicesIter = indices.rbegin();
+  while (offsetsIter != offsets.rend()) {
+    Value original = *originalIter;
+    int64_t offset = *offsetsIter;
+    if (offset != 0)
+      *indicesIter = rewriter.create<arith::AddIOp>(
+          loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
+    originalIter++;
+    offsetsIter++;
+    indicesIter++;
+  }
+  return indices;
+};
+
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
+  UnrollLoadPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = loadOp.getVectorType();
+    // Only unroll >1D loads
+    if (vecType.getRank() <= 1)
+      return failure();
+
+    Location loc = loadOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+
+    // Target type is a 1D vector of the innermost dimension.
+    auto targetType =
+        VectorType::get(originalShape.back(), vecType.getElementType());
+
+    // Extend targetShape to the rank of the original shape by padding the
+    // leading dimensions with 1s, for convenience when computing offsets.
+    SmallVector<int64_t> targetShape(originalShape.size(), 1);
+    targetShape.back() = originalShape.back();
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, vecType, rewriter.getZeroAttr(vecType));
+
+    SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
+                                       loadOp.getIndices().end());
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, targetShape)) {
+      SmallVector<Value> indices =
+          computeIndices(rewriter, loc, originalIndices, offsets);
+      Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
+                                                    loadOp.getBase(), indices);
+      // Insert the slice into the result at the correct position.
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, slice, result, offsets, SmallVector<int64_t>({1}));
+    }
+    rewriter.replaceOp(loadOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
+  UnrollStorePattern(MLIRContext *context,
+                     const vector::UnrollVectorOptions &options,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = storeOp.getVectorType();
+    // Only unroll >1D stores.
+    if (vecType.getRank() <= 1)
+      return failure();
+
+    Location loc = storeOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+
+    // Extend targetShape to the rank of the original shape by padding the
+    // leading dimensions with 1s, for convenience when computing offsets.
+    SmallVector<int64_t> targetShape(originalShape.size(), 1);
+    targetShape.back() = originalShape.back();
+
+    Value base = storeOp.getBase();
+    Value vector = storeOp.getValueToStore();
+
+    SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
+                                       storeOp.getIndices().end());
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, targetShape)) {
+      SmallVector<Value> indices =
+          computeIndices(rewriter, loc, originalIndices, offsets);
+      offsets.pop_back();
+      Value slice = rewriter.create<vector::ExtractOp>(loc, vector, offsets);
+      rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   UnrollBroadcastPattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
@@ -699,10 +818,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns
-      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-           UnrollContractionPattern, UnrollElementwisePattern,
-           UnrollReductionPattern, UnrollMultiReductionPattern,
-           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
-          patterns.getContext(), options, benefit);
+  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+               UnrollContractionPattern, UnrollElementwisePattern,
+               UnrollReductionPattern, UnrollMultiReductionPattern,
+               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
+               UnrollStorePattern, UnrollBroadcastPattern>(
+      patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
new file mode 100644
index 0000000000000..3135268b8d61b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @unroll_2D_vector_load(
+// CHECK-SAME:  %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: return %[[V7]] : vector<4x4xf16>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+
+// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-SAME:  %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+  // CHECK: return %[[V3]] : vector<2x2xf16>
+  %c1 = arith.constant 1 : index
+  %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+  return %0 : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
+func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
+  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
+  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  %c1 = arith.constant 1 : index
+  vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 54aa96ba89a00..8014362a1a6ec 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -292,6 +292,44 @@ struct TestVectorTransferUnrollingPatterns
       llvm::cl::init(false)};
 };
 
+struct TestVectorLoadStoreUnrollPatterns
+    : public PassWrapper<TestVectorLoadStoreUnrollPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorLoadStoreUnrollPatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-load-store-unroll";
+  }
+  StringRef getDescription() const final {
+    return "Test unrolling patterns for vector.load and vector.store ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect, arith::ArithDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+
+    // Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors
+    vector::UnrollVectorOptions options;
+    options.setFilterConstraint([](Operation *op) {
+      if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+        return success(loadOp.getType().getRank() > 1);
+      if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+        return success(storeOp.getVectorType().getRank() > 1);
+      return failure();
+    });
+
+    vector::populateVectorUnrollPatterns(patterns, options);
+
+    // Apply the patterns
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestScalarVectorTransferLoweringPatterns
     : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1032,6 +1070,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferUnrollingPatterns>();
 
+  PassRegistration<TestVectorLoadStoreUnrollPatterns>();
+
   PassRegistration<TestScalarVectorTransferLoweringPatterns>();
 
   PassRegistration<TestVectorTransferOpt>();

@llvmbot (Member) commented Jun 9, 2025

@llvm/pr-subscribers-mlir-vector

@banach-space (Contributor) left a comment

LG % some minor asks, thanks!

@nbpatel (Contributor, Author) commented Jun 10, 2025

Thanks @banach-space for the feedback; I've addressed it. Please take another look whenever you have time.

@banach-space (Contributor) left a comment

Hey, thanks for the update!

On second scan, it feels that computeIndices could be simplified. Please take a look. Thank you!

@banach-space (Contributor) left a comment

LGTM, thanks for addressing all my comments, the final version is much easier to follow 🙏🏻

@newling , IIRC these patterns were of interest to you as well? @nbpatel , would you mind waiting ~1 day before landing (to give others a chance to take a look)?

@newling (Contributor) commented Jun 12, 2025

Thanks for moving this from linearization to unrolling, i.e. from #136193 and #137558

As far as I can tell the existing 9 unroll patterns support unrolling to any tile size of the vector. See for example

elementwise:

// CHECK-LABEL: func @elementwise_unroll

transfer_read:

func.func @transfer_read_unroll(%mem : memref<4x4xf32>) -> vector<4x4xf32> {

and the tests in the file you added your tests to in this PR. Can you motivate deviating from this design? For example, based on the other unroll patterns, I would expect providing a 'target shape' of 2x2 to convert a load of 6x6 into 9 loads of 2x2.

I think that transfer_read is a more complex op than vector load. For example, transfer_read is lowered to load here.
So I think it should be possible to unroll load to any tile size, if that's what happens for transfer_reads.

In terms of achieving what I think is your goal of only having rank-1 loads, I propose the following (a rough sketch follows the list):

  1. use unrolling to go from loads of 6x6 to 6 loads of 1x6
  2. linearize the loads of 1x6 to loads of 6
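
A rough sketch of that two-step lowering on a smaller 2x6 example (hypothetical IR, not taken from the PR; it assumes the linearization step restores the original 1x6 type with a vector.shape_cast):

// After step 1 (unrolling): one 1x6 load per row.
func.func @after_unroll(%mem: memref<2x6xf32>) -> (vector<1x6xf32>, vector<1x6xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %r0 = vector.load %mem[%c0, %c0] : memref<2x6xf32>, vector<1x6xf32>
  %r1 = vector.load %mem[%c1, %c0] : memref<2x6xf32>, vector<1x6xf32>
  return %r0, %r1 : vector<1x6xf32>, vector<1x6xf32>
}

// After step 2 (linearization): each 1x6 load becomes a 1-D load of 6.
func.func @after_linearize(%mem: memref<2x6xf32>) -> (vector<1x6xf32>, vector<1x6xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %l0 = vector.load %mem[%c0, %c0] : memref<2x6xf32>, vector<6xf32>
  %r0 = vector.shape_cast %l0 : vector<6xf32> to vector<1x6xf32>
  %l1 = vector.load %mem[%c1, %c0] : memref<2x6xf32>, vector<6xf32>
  %r1 = vector.shape_cast %l1 : vector<6xf32> to vector<1x6xf32>
  return %r0, %r1 : vector<1x6xf32>, vector<1x6xf32>
}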

In my opinion, having a linearization pattern for step 2 above is fine, because it converts one rank-2 op into one rank-1 op. That's what I was kind of driving at when I added this comment:

/// Definition: here 'linearization' means converting a single operation with

I acknowledge that the end result is exactly what you had in #136193! But I think it makes Vector more modular and intuitive for other users if unrolling and linearization have a clear boundary.

@@ -178,6 +178,16 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::TransposeOp>(op));
}));

populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})

I think {2,2} should be used

@nbpatel (Contributor, Author) commented Jun 13, 2025

The motivation for this unrolling was to have 1-D loads and stores, hence the target shape is set accordingly. But I agree with your point about making it more modular to align with the vector dialect's flexible unrolling design. Can I address this in a follow-up PR?

@newling (Contributor) commented Jun 13, 2025

The two designs seem orthogonal, so IMO it doesn't make sense to do it as a follow-up.

@nbpatel nbpatel requested review from Groverkss and kuhar as code owners June 20, 2025 14:32
@nbpatel (Contributor, Author) commented Jun 20, 2025

Hi @newling, I changed it to follow a similar design to the other patterns. Can you take a look?
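
For reference, a sketch of what the reworked, native-shape-driven unrolling might produce for a 4x4 load with a native shape of 2x2 (illustrative IR only; it assumes the final version composes the 2x2 slices with vector.insert_strided_slice, as the transfer_read unrolling does):

func.func @load_unrolled_to_2x2(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %cst = arith.constant dense<0.000000e+00> : vector<4x4xf16>
  // One 2x2 load per tile, inserted at its (row, col) offset.
  %t00 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<2x2xf16>
  %a0 = vector.insert_strided_slice %t00, %cst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
  %t01 = vector.load %mem[%c0, %c2] : memref<4x4xf16>, vector<2x2xf16>
  %a1 = vector.insert_strided_slice %t01, %a0 {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
  %t10 = vector.load %mem[%c2, %c0] : memref<4x4xf16>, vector<2x2xf16>
  %a2 = vector.insert_strided_slice %t10, %a1 {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
  %t11 = vector.load %mem[%c2, %c2] : memref<4x4xf16>, vector<2x2xf16>
  %a3 = vector.insert_strided_slice %t11, %a2 {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
  return %a3 : vector<4x4xf16>
}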

github-actions bot commented Jun 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@newling (Contributor) left a comment

Thanks @nbpatel, this looks good to me! My remaining comments are all quite minor

@nbpatel (Contributor, Author) commented Jun 20, 2025

Thanks, I addressed your comments. Can you please approve if it looks ok?

@newling (Contributor) left a comment

LGTM

@nbpatel nbpatel merged commit 9c1ce31 into llvm:main Jun 20, 2025
7 checks passed
Jaddyen pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 23, 2025
…lvm#143420)

This PR adds unroll patterns for vector.load and vector.store. It is a follow-up of llvm#137558.