Skip to content

Commit

Permalink
[mlir][Vector] Support 0-d vectors natively in TransferOpReduceRank (l…
Browse files Browse the repository at this point in the history
…lvm#112907)

Since
llvm@ddf2d62
, 0-d vectors are supported in VectorType. This patch removes 0-d vector
handling with scalars for the TransferOpReduceRank pattern. This pattern
specifically introduces tensor.extract_slice during vectorization,
causing vectorization to not fold transfer_read/transfer_write slices
properly. The changes in vectorization test files reflect this.

There are other places where lowering patterns are still side-stepping
from handling 0-d vectors properly, by turning them into scalars, but
this patch only focuses on the vector.transfer_x patterns.
  • Loading branch information
Groverkss authored and EricWF committed Oct 22, 2024
1 parent 81f6c99 commit a868e74
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 35 deletions.
21 changes: 0 additions & 21 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,31 +358,10 @@ struct TransferOpReduceRank
op, "map is not a minor identity with broadcasting");
}

// TODO: support zero-dimension vectors natively. See:
// https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
// In the meantime, lower these to a scalar load when they pop up.
if (reducedShapeRank == 0) {
Value newRead;
if (isa<TensorType>(op.getShapedType())) {
newRead = rewriter.create<tensor::ExtractOp>(
op.getLoc(), op.getSource(), op.getIndices());
} else {
newRead = rewriter.create<memref::LoadOp>(
op.getLoc(), originalVecType.getElementType(), op.getSource(),
op.getIndices());
}
return rewriter
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
.getVector();
}

SmallVector<int64_t> newShape(
originalVecType.getShape().take_back(reducedShapeRank));
SmallVector<bool> newScalableDims(
originalVecType.getScalableDims().take_back(reducedShapeRank));
// Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
if (newShape.empty())
return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

VectorType newReadType = VectorType::get(
newShape, originalVecType.getElementType(), newScalableDims);
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ func.func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.to

// CHECK-LABEL: transfer_read_with_tensor
func.func @transfer_read_with_tensor(%arg: tensor<f32>) -> vector<1xf32> {
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[] : tensor<f32>
// CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : f32 to vector<1xf32>
// CHECK: %[[EXTRACTED:.*]] = vector.transfer_read %{{.*}}[], %{{.*}} : tensor<f32>, vector<f32>
// CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : vector<f32> to vector<1xf32>
// CHECK-NEXT: return %[[RESULT]] : vector<1xf32>
%f0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %arg[], %f0 {permutation_map = affine_map<()->(0)>} :
Expand Down
20 changes: 10 additions & 10 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>

// Same as example above, but reading into a column tensor. Note that after the
// vectorizatoin, the `TransferOpReduceRank` will replace
// `vector.transfer_read` with `tensor.extract -> scalar`.
// Same as example above, but reading into a column tensor.

// TODO: Currently this fails to vectorise when the indices are non-constant.

Expand All @@ -162,9 +160,10 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
// CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>,
// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[READ:.*]] = vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[CST_0]] : tensor<3x3x3xf32>, vector<f32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[READ]] : vector<f32> to vector<3x1x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
// CHECK: return %[[RES]] : tensor<3x1x1xf32>

Expand Down Expand Up @@ -541,8 +540,9 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
// CHECK-SAME: %[[INPUT_2:.*]]: tensor<257x24xf32>,
// CHECK: %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
// First `tensor.extract` from the generic Op - loop invariant scalar load.
// CHECK: tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32>
// First `vector.transfer_read` from the generic Op - loop invariant scalar load.
// CHECK: vector.transfer_read %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]]
// CHECK-SAME: tensor<1x20xi32>, vector<i32>
// The following `tensor.extract` from the generic Op s a contiguous load (all Ops used
// for address calculation also satisfy the required conditions).
// CHECK: vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}, %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
Expand Down Expand Up @@ -745,8 +745,8 @@ func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf

// CHECK-LABEL: func.func @vectorize_0d_tensor_extract(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
// CHECK: vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
// CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[ARG_0]][], %{{.+}} : tensor<f32>
// CHECK: vector.broadcast %[[EXTRACT]] : vector<f32> to vector<1x1x3xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
%f0 = arith.constant 0.0 : f32

// CHECK-NEXT: %[[S:.*]] = tensor.extract %[[SRC]][] : tensor<f32>
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32>
// CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][]
// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
%res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
tensor<f32>, vector<1xf32>

Expand Down

0 comments on commit a868e74

Please sign in to comment.