[MLIR] [Vector] Linearization patterns for vector.load and vector.store #145115

Open

wants to merge 3 commits into base: main

Conversation

nbpatel
Contributor

@nbpatel nbpatel commented Jun 20, 2025

This PR adds linearization patterns for vector.load and vector.store. It is a follow-up to #143420 (comment).
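
For reference, a minimal before/after sketch of the rewrite, mirroring the doc comments and test cases in the diff (the value and argument names such as %src, %dst, and %v are illustrative):

```mlir
// Before linearization:
%0 = vector.load %src[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
vector.store %v, %dst[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>

// After linearization: rank-1 load/store plus shape_casts at the boundary.
%1 = vector.load %src[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
%2 = vector.shape_cast %1 : vector<4xf32> to vector<1x4xf32>
%3 = vector.shape_cast %v : vector<1x4xf32> to vector<4xf32>
vector.store %3, %dst[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
```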

@nbpatel
Contributor Author

nbpatel commented Jun 20, 2025

@newling following up on #143420 (comment)

For 2: is this what you meant? I created a draft PR because that seemed quicker.

@newling
Contributor

newling commented Jun 20, 2025

@newling following up on #143420 (comment)

For 2: is this what you meant? I created a draft PR because that seemed quicker.

Yeah, this is what I meant 👍

Ideally we'll eventually have something like the flattening of transfer_read done here, i.e. linearize even when there is more than one dimension of size > 1, and it needn't be the innermost dim. But I guess that can wait. FWIW, IMO that transfer_read code should be in VectorLinearize too; I mentioned that at the bottom of this comment. The vector.load linearization code could probably then reuse some of it. Something for the future, maybe!
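
For illustration, a hypothetical sketch of what that more general flattening might look like for vector.load, assuming a contiguous row-major memref and zero indices (the collapse_shape-based approach shown here echoes the transfer_read flattening and is not part of this PR; non-zero indices would additionally need linearized index arithmetic):

```mlir
// Hypothetical input with more than one non-unit dimension:
%0 = vector.load %src[%c0, %c0] : memref<2x4xf32>, vector<2x4xf32>

// One possible flattened form: collapse the memref, do a rank-1 load,
// and restore the original shape with a shape_cast.
%flat = memref.collapse_shape %src [[0, 1]] : memref<2x4xf32> into memref<8xf32>
%1 = vector.load %flat[%c0] : memref<8xf32>, vector<8xf32>
%2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
```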

@nbpatel
Contributor Author

nbpatel commented Jun 23, 2025

@newling can you review it as well?

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145115.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+69-2)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
   }
 };
 
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only linearization of vector<1xN> to vector<N>.
+/// For example:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = loadOp.getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+    auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+        loadOp.getLoc(), vecTy, newLoad.getResult());
+    rewriter.replaceOp(loadOp, shapeCast.getResult());
+    return success();
+  }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only linearization of vector<1xN> to vector<N>.
+/// For example:
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+///   %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///   vector.store %0, %arg1[%c0, %c0]
+///     : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = storeOp.getValueToStore().getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+
+    Value valueToStore = adaptor.getValueToStore();
+    if (valueToStore.getType() != linearTy) {
+      valueToStore = rewriter.create<vector::ShapeCastOp>(
+          storeOp.getLoc(), linearTy, valueToStore);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+           LinearizeVectorStore>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
   %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
   return %0 : vector<1x[16]xi1>
 }
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+  // CHECK: return %[[CAST]] : vector<1x4xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return
+}

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145115.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+69-2)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
   }
 };
 
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = loadOp.getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+    auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+        loadOp.getLoc(), vecTy, newLoad.getResult());
+    rewriter.replaceOp(loadOp, shapeCast.getResult());
+    return success();
+  }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///   vector.store %arg0, %arg1[%c0, %%c0]
+///     : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = storeOp.getValueToStore().getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+
+    Value valueToStore = adaptor.getValueToStore();
+    if (valueToStore.getType() != linearTy) {
+      valueToStore = rewriter.create<vector::ShapeCastOp>(
+          storeOp.getLoc(), linearTy, valueToStore);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+           LinearizeVectorStore>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
   %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
   return %0 : vector<1x[16]xi1>
 }
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+  // CHECK: return %[[CAST]] : vector<1x4xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return
+}

Contributor

@newling newling left a comment


Thanks! I added 2 small comments about simplifying and generalizing a bit.

@nbpatel
Contributor Author

nbpatel commented Jun 23, 2025

Thanks! I added 2 small comments about simplifying and generalizing a bit.

Addressed the feedback, thanks :)
