From 59fbba94908f65eedb8bdd619e425bf97d84b2e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Wed, 20 Sep 2023 09:56:43 +0100 Subject: [PATCH] [mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.splat (#66596) Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcasting op can be either `vector.broadcast` (already supported) or `vector.splat` (support added in this patch). --- .../Vector/Transforms/VectorTransforms.cpp | 40 ++++++++++++------- .../Dialect/Vector/sink-vector-broadcast.mlir | 39 +++++++++++++++--- 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 207df69929c1c9..b2a5aef5ee62d0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -880,7 +880,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern { std::function controlFn; }; -/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex: +/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: /// ``` /// %a = vector.broadcast %arg1 : index to vector<1x4xindex> /// %b = vector.broadcast %arg2 : index to vector<1x4xindex> @@ -891,6 +891,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern { /// %r = arith.addi %arg0, %arg1 : index /// %b = vector.broadcast %r : index to vector<1x4xindex> /// ``` +/// +/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting +/// ops. 
struct ReorderElementwiseOpsOnBroadcast final : public OpTraitRewritePattern { using OpTraitRewritePattern::OpTraitRewritePattern; @@ -903,35 +906,42 @@ struct ReorderElementwiseOpsOnBroadcast final if (!OpTrait::hasElementwiseMappableTraits(op)) return failure(); - // Get the type of the first operand - auto firstBcast = op->getOperand(0).getDefiningOp(); - if (!firstBcast) + // Get the type of the lhs operand + auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); + if (!lhsBcastOrSplat || + !isa(*lhsBcastOrSplat)) return failure(); - auto firstOpType = firstBcast.getOperand().getType(); + auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); - // Make sure that operands are "broadcast"ed from identical (scalar or - // vector) types. That indicates that it's safe to skip the broadcasting of - // operands. - if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) { + // Make sure that all operands are broadcast from identical types: + // * scalar (`vector.broadcast` + `vector.splat`), or + // * vector (`vector.broadcast`). + // Otherwise the re-ordering wouldn't be safe. 
+ if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { auto bcast = val.getDefiningOp(); - return (bcast && (bcast.getOperand().getType() == firstOpType)); + if (bcast) + return (bcast.getOperand().getType() == lhsBcastOrSplatType); + auto splat = val.getDefiningOp(); + if (splat) + return (splat.getOperand().getType() == lhsBcastOrSplatType); + return false; })) { return failure(); } - // Collect the source values + // Collect the source values before broadcasting SmallVector srcValues; srcValues.reserve(op->getNumOperands()); - for (Value operand : op->getOperands()) { - srcValues.push_back( - operand.getDefiningOp().getOperand()); + srcValues.push_back(operand.getDefiningOp()->getOperand(0)); } + // Create the "elementwise" Op Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, - firstOpType, op->getAttrs()); + lhsBcastOrSplatType, op->getAttrs()); + // Replace the original Op with the elementwise Op auto vectorType = op->getResultTypes()[0]; rewriter.replaceOpWithNewOp( op, vectorType, elementwiseOp->getResults()); diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir index fcf9815f6f6f1d..d9d2f44e6f16c1 100644 --- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir +++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir @@ -1,13 +1,12 @@ // RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s -// CHECK-LABEL: func.func @broadcast_scalar( +// CHECK-LABEL: func.func @broadcast_scalar_with_bcast( // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> { // CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> // CHECK: return %[[BCAST]] : vector<1x4xindex> -// CHECK: } -func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> { +func.func @broadcast_scalar_with_bcast( 
%arg1: index, %arg2: index) -> vector<1x4xindex> { %0 = vector.broadcast %arg1 : index to vector<1x4xindex> %1 = vector.broadcast %arg2 : index to vector<1x4xindex> %2 = arith.addi %0, %1 : vector<1x4xindex> @@ -16,13 +15,27 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> { // ----- +// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat( +// CHECK-SAME: %[[ARG1:.*]]: index, +// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> +func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> { + %0 = vector.splat %arg1 : vector<1x4xindex> + %1 = vector.broadcast %arg2 : index to vector<1x4xindex> + %2 = arith.addi %0, %1 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + // CHECK-LABEL: func.func @broadcast_vector( // CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>, // CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> { // CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32> // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32> // CHECK: return %[[BCAST]] : vector<3x4xf32> -// CHECK: } func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> { %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32> @@ -30,6 +43,23 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32> return %2 : vector<3x4xf32> } + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_vec( +// CHECK-SAME: %[[ARG1:.*]]: index, +// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> { +// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex> +// CHECK: %[[BCAST:.*]] = vector.broadcast 
%[[ARG2]] : vector<4xindex> to vector<1x4xindex> +// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex> +// CHECK: return %[[ADD]] : vector<1x4xindex> +func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> { + %0 = vector.splat %arg1 : vector<1x4xindex> + %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex> + %2 = arith.addi %0, %1 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + // ----- // CHECK-LABEL: func.func @broadcast_vector_and_scalar( @@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32> // CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32> // CHECK: return %[[ADD]] : vector<4xi32> -// CHECK: } func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> { %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>