diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td index ce1355316b09b8..230a3815bdd81e 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -95,6 +95,8 @@ def Index_MulOp : IndexBinaryOp<"mul", [Commutative, Pure]> { %c = index.mul %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -263,6 +265,8 @@ def Index_MaxSOp : IndexBinaryOp<"maxs", [Commutative, Pure]> { %c = index.maxs %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -283,6 +287,8 @@ def Index_MaxUOp : IndexBinaryOp<"maxu", [Commutative, Pure]> { %c = index.maxu %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -302,6 +308,8 @@ def Index_MinSOp : IndexBinaryOp<"mins", [Commutative, Pure]> { %c = index.mins %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -322,6 +330,8 @@ def Index_MinUOp : IndexBinaryOp<"minu", [Commutative, Pure]> { %c = index.minu %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -404,6 +414,8 @@ def Index_AndOp : IndexBinaryOp<"and", [Commutative, Pure]> { %c = index.and %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -423,6 +435,8 @@ def Index_OrOp : IndexBinaryOp<"or", [Commutative, Pure]> { %c = index.or %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -442,6 +456,8 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> { %c = index.xor %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp index 0b58eb80f93032..5c935c5f4b53e3 100644 --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -118,6 +118,32 @@ static OpFoldResult foldBinaryOpChecked( return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64); } +/// Helper for associative and commutative binary ops that can be transformed: +/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)` +/// where c1 and c2 are constants. It is expected that `tmp` will be folded. +template +LogicalResult +canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, + PatternRewriter &rewriter) { + if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant())) + return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant"); + + auto lhsOp = op.getLhs().template getDefiningOp(); + if (!lhsOp) + return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp"); + + if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant())) + return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant"); + + Value c = rewriter.createOrFold(op->getLoc(), op.getRhs(), + lhsOp.getRhs()); + if (c.getDefiningOp()) + return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded"); + + rewriter.replaceOpWithNewOp(op, lhsOp.getLhs(), c); + return success(); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -136,27 +162,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { return {}; } -/// Canonicalize -/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)` -LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { - IntegerAttr c1, c2; - if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1))) - return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant"); - - auto add = op.getLhs().getDefiningOp(); - if (!add) - return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add"); - - if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2))) - return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant"); - - auto c = rewriter.create(op->getLoc(), - c1.getInt() + c2.getInt()); - auto newAdd = - rewriter.create(op->getLoc(), add.getLhs(), c); - rewriter.replaceOp(op, newAdd); - return success(); +LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); } //===----------------------------------------------------------------------===// @@ -200,6 +208,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { return {}; } +LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // DivSOp //===----------------------------------------------------------------------===// @@ -352,6 +364,10 @@ OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) { }); } +LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // MaxUOp //===----------------------------------------------------------------------===// @@ -363,6 +379,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) { }); } +LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // MinSOp //===----------------------------------------------------------------------===// @@ -374,6 +394,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) { }); } +LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // MinUOp //===----------------------------------------------------------------------===// @@ -385,6 +409,10 @@ OpFoldResult MinUOp::fold(FoldAdaptor adaptor) { }); } +LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // ShlOp //===----------------------------------------------------------------------===// @@ -442,6 +470,10 @@ OpFoldResult AndOp::fold(FoldAdaptor adaptor) { [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; }); } +LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // OrOp //===----------------------------------------------------------------------===// @@ -452,6 +484,10 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) { [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; }); } +LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// @@ -462,6 +498,10 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; }); } +LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) { + return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); +} + //===----------------------------------------------------------------------===// // CastSOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir index a29b09c11f7f62..45da6ea57d796e 100644 --- a/mlir/test/Dialect/Index/index-canonicalize.mlir +++ b/mlir/test/Dialect/Index/index-canonicalize.mlir @@ -32,15 +32,15 @@ func.func @add_overflow() -> (index, index) { return %2, %3 : index, index } -// CHECK-LABEL: @add +// CHECK-LABEL: @add_fold_constants func.func @add_fold_constants(%arg: index) -> (index) { %0 = index.constant 1 %1 = index.constant 2 %2 = index.add %arg, %0 %3 = index.add %2, %1 - // CHECK-DAG: [[C3:%.*]] = index.constant 3 - // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]] + // CHECK: [[C3:%.*]] = index.constant 3 + // CHECK: [[V0:%.*]] = index.add %arg0, [[C3]] // CHECK: return [[V0]] return %3 : index } @@ -65,6 +65,19 @@ func.func @mul() -> index { return %2 : index } +// CHECK-LABEL: @mul_fold_constants +func.func @mul_fold_constants(%arg: index) -> (index) { + %0 = index.constant 2 + %1 = index.constant 3 + %2 = index.mul %arg, %0 + %3 = index.mul %2, %1 + + // CHECK: [[C6:%.*]] = index.constant 6 + // CHECK: [[V0:%.*]] = index.mul %arg0, [[C6]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @divs func.func @divs() -> index { %0 = index.constant -2 @@ -300,6 +313,19 @@ func.func @maxs_edge() -> index { return %0 : index } +// CHECK-LABEL: @maxs_fold_constants +func.func @maxs_fold_constants(%arg: index) -> (index) { + %0 = index.constant -2 + %1 = index.constant 3 + %2 = index.maxs %arg, %0 + %3 = index.maxs %2, %1 + + // CHECK: [[C3:%.*]] = index.constant 3 + // CHECK: [[V0:%.*]] = index.maxs %arg0, [[C3]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @maxu func.func @maxu() -> index { %lhs = index.constant -1 @@ -310,6 +336,19 @@ func.func @maxu() -> index { return %0 : index } +// CHECK-LABEL: @maxu_fold_constants +func.func @maxu_fold_constants(%arg: index) -> (index) { + %0 = index.constant 2 + %1 = index.constant 3 + %2 = index.maxu %arg, %0 + %3 = index.maxu %2, %1 + + // CHECK: [[C3:%.*]] = index.constant 3 + // CHECK: [[V0:%.*]] = index.maxu %arg0, [[C3]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @mins func.func @mins() -> index { %lhs = index.constant -4 @@ -340,6 +379,19 @@ func.func @mins_nofold_2() -> index { return %0 : index } +// CHECK-LABEL: @mins_fold_constants +func.func @mins_fold_constants(%arg: index) -> (index) { + %0 = index.constant -2 + %1 = index.constant 3 + %2 = index.mins %arg, %0 + %3 = index.mins %2, %1 + + // CHECK: [[C2:%.*]] = index.constant -2 + // CHECK: [[V0:%.*]] = index.mins %arg0, [[C2]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @minu func.func @minu() -> index { %lhs = index.constant -1 @@ -350,6 +402,19 @@ func.func @minu() -> index { return %0 : index } +// CHECK-LABEL: @minu_fold_constants +func.func @minu_fold_constants(%arg: index) -> (index) { + %0 = index.constant 2 + %1 = index.constant 3 + %2 = index.minu %arg, %0 + %3 = index.minu %2, %1 + + // CHECK: [[C2:%.*]] = index.constant 2 + // CHECK: [[V0:%.*]] = index.minu %arg0, [[C2]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @shl func.func @shl() -> index { %lhs = index.constant 128 @@ -465,6 +530,19 @@ func.func @and() -> index { return %0 : index } +// CHECK-LABEL: @and_fold_constants +func.func @and_fold_constants(%arg: index) -> (index) { + %0 = index.constant 5 + %1 = index.constant 1 + %2 = index.and %arg, %0 + %3 = index.and %2, %1 + + // CHECK: [[C1:%.*]] = index.constant 1 + // CHECK: [[V0:%.*]] = index.and %arg0, [[C1]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @or func.func @or() -> index { %lhs = index.constant 5 @@ -475,6 +553,19 @@ func.func @or() -> index { return %0 : index } +// CHECK-LABEL: @or_fold_constants +func.func @or_fold_constants(%arg: index) -> (index) { + %0 = index.constant 5 + %1 = index.constant 1 + %2 = index.or %arg, %0 + %3 = index.or %2, %1 + + // CHECK: [[C5:%.*]] = index.constant 5 + // CHECK: [[V0:%.*]] = index.or %arg0, [[C5]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @xor func.func @xor() -> index { %lhs = index.constant 5 @@ -485,6 +576,19 @@ func.func @xor() -> index { return %0 : index } +// CHECK-LABEL: @xor_fold_constants +func.func @xor_fold_constants(%arg: index) -> (index) { + %0 = index.constant 5 + %1 = index.constant 1 + %2 = index.xor %arg, %0 + %3 = index.xor %2, %1 + + // CHECK: [[C4:%.*]] = index.constant 4 + // CHECK: [[V0:%.*]] = index.xor %arg0, [[C4]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @cmp func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) { %a = index.constant 0