diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 693f4f955994d..be911901c2afc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -169,7 +169,13 @@ struct UnrollTransferReadPattern
     auto sourceVectorType = readOp.getVectorType();
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = readOp.getLoc();
-    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    // Bail out if rank(source) != rank(target). The main limitation here is
+    // that `InsertStridedSliceOp` requires the input and output ranks to
+    // match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          readOp, "expected source vector rank to match target shape rank");
 
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +230,14 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    // Bail out if rank(source) != rank(target). The main limitation here is
+    // that `ExtractStridedSliceOp` requires the input and output ranks to
+    // match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          writeOp,
+          "expected source vector rank to match target shape rank");
+
     SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                        writeOp.getIndices().end());
     SmallVector<int64_t> loopOrder =
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..8f6945468feb3 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,25 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
 // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
 // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
 // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+func.func @vector_transfer_read(%arg0: memref<6x34x62xi8>) -> vector<6x34x62xi8> {
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 : memref<6x34x62xi8>, vector<6x34x62xi8>
+  return %0 : vector<6x34x62xi8>
+}
+// CHECK-LABEL: func @vector_transfer_read
+// CHECK-NOT: vector.insert_strided_slice
+// CHECK: vector.transfer_read
+// CHECK: return
+
+func.func @vector_transfer_write(%arg0: vector<6x34x62xi8>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<6x34x62xi8>
+  vector.transfer_write %arg0, %alloc[%c0, %c0, %c0] : vector<6x34x62xi8>, memref<6x34x62xi8>
+  return
+}
+// CHECK-LABEL: func @vector_transfer_write
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: vector.transfer_write
+// CHECK: return
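
Editor's note (not part of the patch): the guard added above exists because
vector.insert_strided_slice and vector.extract_strided_slice only accept source
and result vectors of matching rank, so the unroll patterns cannot tile an n-D
transfer with a lower-rank target shape. A rough, illustrative sketch of the
rank-preserving IR the unroller normally stitches results together with (the
shapes and SSA names here are made up for illustration):

  // Insert one unrolled 3-D tile back into the 3-D accumulator; both operands
  // have rank 3, which is what InsertStridedSliceOp requires.
  %acc1 = vector.insert_strided_slice %tile, %acc0
      {offsets = [0, 0, 0], strides = [1, 1, 1]}
      : vector<2x2x2xi8> into vector<6x34x62xi8>

With a lower-rank target shape (for example, 2-D tiles of the 3-D transfers in
the new tests), no such op can be formed, so the patterns now report a match
failure and leave the transfer_read/transfer_write untouched, as the CHECK-NOT
lines verify.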