From 6f4c849d041c77ca85caab60ab83465d9810f180 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Thu, 17 Jul 2025 11:33:46 +0800 Subject: [PATCH 1/2] [mlir][vector] Add a check to ensure input vector rank equals target shape rank The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an input vector of higher rank using a target vector of lower rank, which is not supported. --- .../Dialect/Vector/Transforms/VectorUnroll.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 693f4f955994d..be911901c2afc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -169,7 +169,13 @@ struct UnrollTransferReadPattern auto sourceVectorType = readOp.getVectorType(); SmallVector strides(targetShape->size(), 1); Location loc = readOp.getLoc(); - ArrayRef originalSize = readOp.getVectorType().getShape(); + ArrayRef originalSize = sourceVectorType.getShape(); + // Bail-out if rank(source) != rank(target). The main limitation here is the + // fact that `InsertStridedSliceOp` requires the rank for the input and + // output to match. If needed, we can relax this later. + if (originalSize.size() != targetShape->size()) + return rewriter.notifyMatchFailure( + readOp, "expected source vector rank to match target shape rank"); // Prepare the result vector; Value result = rewriter.create( @@ -224,6 +230,14 @@ struct UnrollTransferWritePattern SmallVector strides(targetShape->size(), 1); Location loc = writeOp.getLoc(); ArrayRef originalSize = sourceVectorType.getShape(); + // Bail-out if rank(source) != rank(target). The main limitation here is the + // fact that `ExtractStridedSlice` requires the rank for the input and + // output to match. If needed, we can relax this later. + if (originalSize.size() != targetShape->size()) + return rewriter.notifyMatchFailure( + writeOp, + "expected source input vector rank to match target shape rank"); + SmallVector originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector loopOrder = From f17fca627138541759303f0ac068a0d5f6c0381d Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Thu, 17 Jul 2025 11:35:41 +0800 Subject: [PATCH 2/2] Add tests --- .../Dialect/Vector/vector-unroll-options.mlir | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e129cd5c40b9c..8f6945468feb3 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -420,3 +420,25 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) { // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + +func.func @vector_transfer_read(%arg0: memref<6x34x62xi8>) -> vector<6x34x62xi8> { + %c0_i8 = arith.constant 0 : i8 + %c0 = arith.constant 0 : index + %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 : memref<6x34x62xi8>, vector<6x34x62xi8> + return %0 : vector<6x34x62xi8> +} +// CHECK-LABEL: func @vector_transfer_read +// CHECK-NOT: vector.intert_strided_slice +// CHECK: vector.transfer_read +// CHECK: return + +func.func @vector_transfer_write(%arg0: vector<6x34x62xi8>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<6x34x62xi8> + vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8> + return +} +// CHECK-LABEL: func @vector_transfer_write +// CHECK-NOT: vector.extract_strided_slice +// CHECK: vector.transfer_write +// CHECK: return