Skip to content

Commit

Permalink
[mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.s…
Browse files Browse the repository at this point in the history
…plat (llvm#66596)

Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcasting op
could be either `vector.broadcast` (already supported) as well as
`vector.splat` (support added in this patch).
  • Loading branch information
banach-space authored Sep 20, 2023
1 parent afd7db4 commit 59fbba9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 20 deletions.
40 changes: 25 additions & 15 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
std::function<bool(BitCastOp)> controlFn;
};

/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
Expand All @@ -891,6 +891,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
Expand All @@ -903,35 +906,42 @@ struct ReorderElementwiseOpsOnBroadcast final
if (!OpTrait::hasElementwiseMappableTraits(op))
return failure();

// Get the type of the first operand
auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
if (!firstBcast)
// Get the type of the lhs operand
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
if (!lhsBcastOrSplat ||
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
return failure();
auto firstOpType = firstBcast.getOperand().getType();
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();

// Make sure that operands are "broadcast"ed from identical (scalar or
// vector) types. That indicates that it's safe to skip the broadcasting of
// operands.
if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
return (bcast && (bcast.getOperand().getType() == firstOpType));
if (bcast)
return (bcast.getOperand().getType() == lhsBcastOrSplatType);
auto splat = val.getDefiningOp<vector::SplatOp>();
if (splat)
return (splat.getOperand().getType() == lhsBcastOrSplatType);
return false;
})) {
return failure();
}

// Collect the source values
// Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());

for (Value operand : op->getOperands()) {
srcValues.push_back(
operand.getDefiningOp<vector::BroadcastOp>().getOperand());
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
}

// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
firstOpType, op->getAttrs());
lhsBcastOrSplatType, op->getAttrs());

// Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, vectorType, elementwiseOp->getResults());
Expand Down
39 changes: 34 additions & 5 deletions mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @broadcast_scalar(
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
// CHECK: }

func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
Expand All @@ -16,20 +15,51 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {

// -----

// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>
// CHECK: }

func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
return %2 : vector<3x4xf32>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
// CHECK: return %[[ADD]] : vector<1x4xindex>
func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
Expand All @@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
// CHECK: return %[[ADD]] : vector<4xi32>
// CHECK: }

func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
Expand Down

0 comments on commit 59fbba9

Please sign in to comment.