diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 3b308716c84dc..c7bc944c2db18 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -896,6 +896,18 @@ OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) { if (Value result = foldAndIofAndI(*this)) return result; + /// and(a, or(a, b)) -> a + for (int i = 0; i < 2; i++) { + auto a = getOperand(1 - i); + if (auto orOp = getOperand(i).getDefiningOp()) { + for (int j = 0; j < 2; j++) { + if (orOp->getOperand(j) == a) { + return a; + } + } + } + } + return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) & b; }); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index d62c5b18fd041..fb29e72ef9b94 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2901,6 +2901,15 @@ func.func @andand3(%a : i32, %b : i32) -> i32 { return %res : i32 } +// CHECK-LABEL: @andor +// CHECK-SAME: (%[[A:.*]]: i32, %[[B:.*]]: i32) +// CHECK: return %[[A]] +func.func @andor(%a : i32, %b : i32) -> i32 { + %c = arith.ori %a, %b : i32 + %res = arith.andi %a, %b : i32 + return %res : i32 +} + // ----- // CHECK-LABEL: @truncIShrSIToTrunciShrUI