Skip to content

[mlir][vector] Add a check to ensure input vector rank equals target shape rank #149239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Jul 17, 2025

The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an
input vector of higher rank using a target vector of lower rank, which is not supported. Fixes #148368.

CoTinker added 2 commits July 17, 2025 11:33
…shape rank

The crash is caused because, during IR transformation, the
vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an
input vector of higher rank using a target vector of lower rank, which
is not supported.
@llvmbot
Copy link
Member

llvmbot commented Jul 17, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an
input vector of higher rank using a target vector of lower rank, which is not supported. Fixed #148368.


Full diff: https://github.com/llvm/llvm-project/pull/149239.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+15-1)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+22)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 693f4f955994d..be911901c2afc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -169,7 +169,13 @@ struct UnrollTransferReadPattern
     auto sourceVectorType = readOp.getVectorType();
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = readOp.getLoc();
-    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `InsertStridedSliceOp` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          readOp, "expected source vector rank to match target shape rank");
 
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +230,14 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `ExtractStridedSlice` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          writeOp,
+          "expected source input vector rank to match target shape rank");
+
     SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                        writeOp.getIndices().end());
     SmallVector<int64_t> loopOrder =
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..8f6945468feb3 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,25 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
   // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
   // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
   // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+func.func @vector_transfer_read(%arg0: memref<6x34x62xi8>) -> vector<6x34x62xi8> {
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 : memref<6x34x62xi8>, vector<6x34x62xi8>
+  return %0 : vector<6x34x62xi8>
+}
+// CHECK-LABEL: func @vector_transfer_read
+//   CHECK-NOT: vector.intert_strided_slice
+//       CHECK: vector.transfer_read
+//       CHECK: return
+
+func.func @vector_transfer_write(%arg0: vector<6x34x62xi8>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<6x34x62xi8>
+  vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+  return
+}
+// CHECK-LABEL: func @vector_transfer_write
+//   CHECK-NOT: vector.extract_strided_slice
+//       CHECK: vector.transfer_write
+//       CHECK: return

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
2 participants