Skip to content

Commit

Permalink
[MLIR] Add canonicalizations to all eligible index binary ops (llvm…
Browse files Browse the repository at this point in the history
…#114000)

Generalizes the following canonicalization pattern to all associative
and commutative binary ops in the `index` dialect.

```
x = v + c1
y = x + c2
   -->
y = x + (c1 + c2)
```

This includes:
- `AddOp`
- `MulOp`
- `MaxSOp`
- `MaxUOp`
- `MinSOp`
- `MinUOp`
- `AndOp`
- `OrOp`
- `XOrOp`

The operation folding is implemented using the existing folders since
`createAndFold` is used in the canonicalization.
  • Loading branch information
nacgarg authored and PhilippRados committed Nov 6, 2024
1 parent c7547e2 commit 6242552
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 23 deletions.
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Index/IR/IndexOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def Index_MulOp : IndexBinaryOp<"mul", [Commutative, Pure]> {
%c = index.mul %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -263,6 +265,8 @@ def Index_MaxSOp : IndexBinaryOp<"maxs", [Commutative, Pure]> {
%c = index.maxs %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -283,6 +287,8 @@ def Index_MaxUOp : IndexBinaryOp<"maxu", [Commutative, Pure]> {
%c = index.maxu %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -302,6 +308,8 @@ def Index_MinSOp : IndexBinaryOp<"mins", [Commutative, Pure]> {
%c = index.mins %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -322,6 +330,8 @@ def Index_MinUOp : IndexBinaryOp<"minu", [Commutative, Pure]> {
%c = index.minu %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -404,6 +414,8 @@ def Index_AndOp : IndexBinaryOp<"and", [Commutative, Pure]> {
%c = index.and %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -423,6 +435,8 @@ def Index_OrOp : IndexBinaryOp<"or", [Commutative, Pure]> {
%c = index.or %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -442,6 +456,8 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
%c = index.xor %a, %b
```
}];

let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
80 changes: 60 additions & 20 deletions mlir/lib/Dialect/Index/IR/IndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,32 @@ static OpFoldResult foldBinaryOpChecked(
return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
}

/// Helper for associative and commutative binary ops that can be transformed:
/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
/// where c1 and c2 are constants. It is expected that `tmp` will be folded.
template <typename BinaryOp>
LogicalResult
canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op,
PatternRewriter &rewriter) {
if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");

auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
if (!lhsOp)
return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");

if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");

Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
lhsOp.getRhs());
if (c.getDefiningOp<BinaryOp>())
return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");

rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
return success();
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
Expand All @@ -136,27 +162,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {

return {};
}
/// Canonicalize
/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
IntegerAttr c1, c2;
if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");

auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
if (!add)
return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");

if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");

auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
c1.getInt() + c2.getInt());
auto newAdd =
rewriter.create<mlir::index::AddOp>(op->getLoc(), add.getLhs(), c);

rewriter.replaceOp(op, newAdd);
return success();
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -200,6 +208,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return {};
}

LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// DivSOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -352,6 +364,10 @@ OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
});
}

LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// MaxUOp
//===----------------------------------------------------------------------===//
Expand All @@ -363,6 +379,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
});
}

LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// MinSOp
//===----------------------------------------------------------------------===//
Expand All @@ -374,6 +394,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
});
}

LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// MinUOp
//===----------------------------------------------------------------------===//
Expand All @@ -385,6 +409,10 @@ OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
});
}

LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// ShlOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -442,6 +470,10 @@ OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
}

LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//
Expand All @@ -452,6 +484,10 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
}

LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//
Expand All @@ -462,6 +498,10 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
[](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
}

LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
}

//===----------------------------------------------------------------------===//
// CastSOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 6242552

Please sign in to comment.