From 8323ca8956aec45713231e06768a0b330f83cce1 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Wed, 30 Oct 2024 13:54:17 -0400 Subject: [PATCH] Revert "[mlir][Vector] Support 0-d vectors natively in TransferOpReduceRank (#112907)" This reverts commit 1004865f1ca41a9581da8747f34b29862d3ebc3d. --- .../Vector/Transforms/LowerVectorTransfer.cpp | 21 +++++++++++++++++++ .../Conversion/VectorToSCF/vector-to-scf.mlir | 4 ++-- .../Linalg/vectorize-tensor-extract.mlir | 20 +++++++++--------- .../vector-transfer-to-vector-load-store.mlir | 4 ++-- 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index f9428a4ce28640..344cfc0cbffb93 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -358,10 +358,31 @@ struct TransferOpReduceRank op, "map is not a minor identity with broadcasting"); } + // TODO: support zero-dimension vectors natively. See: + // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. + // In the meantime, lower these to a scalar load when they pop up. + if (reducedShapeRank == 0) { + Value newRead; + if (isa(op.getShapedType())) { + newRead = rewriter.create( + op.getLoc(), op.getSource(), op.getIndices()); + } else { + newRead = rewriter.create( + op.getLoc(), originalVecType.getElementType(), op.getSource(), + op.getIndices()); + } + return rewriter + .create(op.getLoc(), originalVecType, newRead) + .getVector(); + } + SmallVector newShape( originalVecType.getShape().take_back(reducedShapeRank)); SmallVector newScalableDims( originalVecType.getScalableDims().take_back(reducedShapeRank)); + // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. 
+ if (newShape.empty()) + return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d"); VectorType newReadType = VectorType::get( newShape, originalVecType.getElementType(), newScalableDims); diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir index 5a6da3a06387a5..c55a0c558bc2f1 100644 --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -503,8 +503,8 @@ func.func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.to // CHECK-LABEL: transfer_read_with_tensor func.func @transfer_read_with_tensor(%arg: tensor) -> vector<1xf32> { - // CHECK: %[[EXTRACTED:.*]] = vector.transfer_read %{{.*}}[], %{{.*}} : tensor, vector - // CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : vector to vector<1xf32> + // CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[] : tensor + // CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : f32 to vector<1xf32> // CHECK-NEXT: return %[[RESULT]] : vector<1xf32> %f0 = arith.constant 0.0 : f32 %0 = vector.transfer_read %arg[], %f0 {permutation_map = affine_map<()->(0)>} : diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir index e611a8e22ee23f..3560ab2312a2e9 100644 --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -136,7 +136,9 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic( // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32> // CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32> -// Same as example above, but reading into a column tensor. 
+// Same as example above, but reading into a column tensor. Note that after the +// vectorization, the `TransferOpReduceRank` will replace +// `vector.transfer_read` with `tensor.extract -> scalar`. // TODO: Currently this fails to vectorise when the indices are non-constant. @@ -160,10 +162,9 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic_column( // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column( // CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>, // CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[READ:.*]] = vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[CST_0]] : tensor<3x3x3xf32>, vector -// CHECK: %[[BCAST:.*]] = vector.broadcast %[[READ]] : vector to vector<3x1x1xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32> // CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32> // CHECK: return %[[RES]] : tensor<3x1x1xf32> @@ -540,9 +541,8 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20 // CHECK-SAME: %[[INPUT_2:.*]]: tensor<257x24xf32>, // CHECK: %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex> -// First `vector.transfer_read` from the generic Op - loop invariant scalar load. -// CHECK: vector.transfer_read %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] -// CHECK-SAME: tensor<1x20xi32>, vector +// First `tensor.extract` from the generic Op - loop invariant scalar load. 
+// CHECK: tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32> // The following `tensor.extract` from the generic Op s a contiguous load (all Ops used // for address calculation also satisfy the required conditions). // CHECK: vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}, %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32> @@ -745,8 +745,8 @@ func.func @vectorize_0d_tensor_extract(%arg0: tensor, %arg2: tensor<1x1x3xf // CHECK-LABEL: func.func @vectorize_0d_tensor_extract( // CHECK-SAME: %[[ARG_0:.*]]: tensor -// CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[ARG_0]][], %{{.+}} : tensor -// CHECK: vector.broadcast %[[EXTRACT]] : vector to vector<1x1x3xf32> +// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor +// CHECK: vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index f90111b4c88618..4d8e4a8296fb5a 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -26,8 +26,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref, %vec: vector<1x1x1xf func.func @vector_transfer_ops_0d_tensor(%src: tensor) -> vector<1xf32> { %f0 = arith.constant 0.0 : f32 -// CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][] -// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector to vector<1xf32> +// CHECK-NEXT: %[[S:.*]] = tensor.extract %[[SRC]][] : tensor +// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32> %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} : tensor, vector<1xf32>