From 82920a32485f7a6cf2b4b311d5dd36b4e91f3eae Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Tue, 12 Nov 2024 15:02:15 +0100 Subject: [PATCH 1/2] [InstCombine] Use KnownBits predicate helpers Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of KnownBits::eq() etc. This gives better results for the equality predicates. I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions. --- .../InstCombine/InstCombineCompares.cpp | 101 +++++++----------- llvm/test/Transforms/InstCombine/icmp-gep.ll | 4 +- .../InstCombine/mul-inseltpoison.ll | 4 +- llvm/test/Transforms/InstCombine/mul.ll | 4 +- .../shift-amount-reassociation-in-bittest.ll | 2 +- 5 files changed, 46 insertions(+), 69 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 5a8814dfd6b3d3..975abf027f6c54 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -6544,6 +6544,35 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI, return false; } +static std::optional compareKnownBits(ICmpInst::Predicate Pred, + const KnownBits &Op0, + const KnownBits &Op1) { + switch (Pred) { + case ICmpInst::ICMP_EQ: + return KnownBits::eq(Op0, Op1); + case ICmpInst::ICMP_NE: + return KnownBits::ne(Op0, Op1); + case ICmpInst::ICMP_ULT: + return KnownBits::ult(Op0, Op1); + case ICmpInst::ICMP_ULE: + return KnownBits::ule(Op0, Op1); + case ICmpInst::ICMP_UGT: + return KnownBits::ugt(Op0, Op1); + case ICmpInst::ICMP_UGE: + return KnownBits::uge(Op0, Op1); + case ICmpInst::ICMP_SLT: + return KnownBits::slt(Op0, Op1); + case ICmpInst::ICMP_SLE: + return KnownBits::sle(Op0, Op1); + case ICmpInst::ICMP_SGT: + return KnownBits::sgt(Op0, Op1); + case ICmpInst::ICMP_SGE: + return KnownBits::sge(Op0, Op1); + default: + llvm_unreachable("Unknown predicate"); + } +} + /// Try to fold the comparison based on range information we can get by checking /// whether bits are known to be zero or one in the inputs. Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { @@ -6576,6 +6605,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { return &I; } + if (!isa(Op0) && Op0Known.isConstant()) + return new ICmpInst( + Pred, ConstantExpr::getIntegerValue(Ty, Op0Known.getConstant()), Op1); + if (!isa(Op1) && Op1Known.isConstant()) + return new ICmpInst( + Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant())); + + if (std::optional Res = compareKnownBits(Pred, Op0Known, Op1Known)) + return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res)); + // Given the known and unknown bits, compute a range that the LHS could be // in. Compute the Min, Max and RHS values based on the known bits. For the // EQ and NE we use unsigned values. @@ -6593,14 +6632,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { Op1Max = Op1Known.getMaxValue(); } - // If Min and Max are known to be the same, then SimplifyDemandedBits figured - // out that the LHS or RHS is a constant. Constant fold this now, so that - // code below can assume that Min != Max. - if (!isa(Op0) && Op0Min == Op0Max) - return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1); - if (!isa(Op1) && Op1Min == Op1Max) - return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min)); - // Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a // min/max canonical compare with some other compare. That could lead to // conflict with select canonicalization and infinite looping. @@ -6682,13 +6713,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { // simplify this comparison. For example, (x&4) < 8 is always true. switch (Pred) { default: - llvm_unreachable("Unknown icmp opcode!"); + break; case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_NE: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith( - I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE)); - // If all bits are known zero except for one, then we know at most one bit // is set. If the comparison is against zero, then this is a check to see if // *that* bit is set. @@ -6728,67 +6755,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { ConstantInt::getNullValue(Op1->getType())); break; } - case ICmpInst::ICMP_ULT: { - if (Op0Max.ult(Op1Min)) // A true if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.uge(Op1Max)) // A false if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } - case ICmpInst::ICMP_UGT: { - if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } - case ICmpInst::ICMP_SLT: { - if (Op0Max.slt(Op1Min)) // A true if max(A) < min(C) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sge(Op1Max)) // A false if min(A) >= max(C) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } - case ICmpInst::ICMP_SGT: { - if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } case ICmpInst::ICMP_SGE: - assert(!isa(Op1) && "ICMP_SGE with ConstantInt not folded!"); - if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_SLE: - assert(!isa(Op1) && "ICMP_SLE with ConstantInt not folded!"); - if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_UGE: - assert(!isa(Op1) && "ICMP_UGE with ConstantInt not folded!"); - if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_ULE: - assert(!isa(Op1) && "ICMP_ULE with ConstantInt not folded!"); - if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; diff --git a/llvm/test/Transforms/InstCombine/icmp-gep.ll b/llvm/test/Transforms/InstCombine/icmp-gep.ll index 887cf1162319bc..776716fe908733 100644 --- a/llvm/test/Transforms/InstCombine/icmp-gep.ll +++ b/llvm/test/Transforms/InstCombine/icmp-gep.ll @@ -583,9 +583,7 @@ define i1 @gep_nusw(ptr %p, i64 %a, i64 %b, i64 %c, i64 %d) { define i1 @pointer_icmp_aligned_with_offset(ptr align 8 %a, ptr align 8 %a2) { ; CHECK-LABEL: @pointer_icmp_aligned_with_offset( -; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 4 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[GEP]], [[A2:%.*]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 false ; %gep = getelementptr i8, ptr %a, i64 4 %cmp = icmp eq ptr %gep, %a2 diff --git a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll index 997758af62a543..8baf6a70fdd5d1 100644 --- a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll +++ b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll @@ -570,12 +570,12 @@ define i64 @test30(i32 %A, i32 %B) { @PR22087 = external global i32 define i32 @test31(i32 %V) { ; CHECK-LABEL: @test31( -; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087 ; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32 ; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]] ; CHECK-NEXT: ret i32 [[MUL1]] ; - %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087 + %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087 %ext = zext i1 %cmp to i32 %shl = shl i32 1, %ext %mul = mul i32 %V, %shl diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll index e38ab1b9622b2c..340828a8d3f9dd 100644 --- a/llvm/test/Transforms/InstCombine/mul.ll +++ b/llvm/test/Transforms/InstCombine/mul.ll @@ -1152,12 +1152,12 @@ define i64 @test30(i32 %A, i32 %B) { @PR22087 = external global i32 define i32 @test31(i32 %V) { ; CHECK-LABEL: @test31( -; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087 ; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32 ; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]] ; CHECK-NEXT: ret i32 [[MUL1]] ; - %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087 + %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087 %ext = zext i1 %cmp to i32 %shl = shl i32 1, %ext %mul = mul i32 %V, %shl diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll index 070a3b03302124..e95955da1b8728 100644 --- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll +++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll @@ -669,7 +669,7 @@ define <2 x i1> @n38_overshift(<2 x i32> %x, <2 x i32> %y) { } ; As usual, don't crash given constantexpr's :/ -@f.a = internal global i16 0 +@f.a = internal global i16 0, align 1 define i1 @constantexpr() { ; CHECK-LABEL: @constantexpr( ; CHECK-NEXT: entry: From 136425bbc2ea761599d7f4b275d5af0dfa28607c Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Tue, 12 Nov 2024 17:42:40 +0100 Subject: [PATCH 2/2] use new helper --- .../InstCombine/InstCombineCompares.cpp | 31 +------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 975abf027f6c54..d602a907e72bcd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -6544,35 +6544,6 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI, return false; } -static std::optional compareKnownBits(ICmpInst::Predicate Pred, - const KnownBits &Op0, - const KnownBits &Op1) { - switch (Pred) { - case ICmpInst::ICMP_EQ: - return KnownBits::eq(Op0, Op1); - case ICmpInst::ICMP_NE: - return KnownBits::ne(Op0, Op1); - case ICmpInst::ICMP_ULT: - return KnownBits::ult(Op0, Op1); - case ICmpInst::ICMP_ULE: - return KnownBits::ule(Op0, Op1); - case ICmpInst::ICMP_UGT: - return KnownBits::ugt(Op0, Op1); - case ICmpInst::ICMP_UGE: - return KnownBits::uge(Op0, Op1); - case ICmpInst::ICMP_SLT: - return KnownBits::slt(Op0, Op1); - case ICmpInst::ICMP_SLE: - return KnownBits::sle(Op0, Op1); - case ICmpInst::ICMP_SGT: - return KnownBits::sgt(Op0, Op1); - case ICmpInst::ICMP_SGE: - return KnownBits::sge(Op0, Op1); - default: - llvm_unreachable("Unknown predicate"); - } -} - /// Try to fold the comparison based on range information we can get by checking /// whether bits are known to be zero or one in the inputs. Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { @@ -6612,7 +6583,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { return new ICmpInst( Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant())); - if (std::optional Res = compareKnownBits(Pred, Op0Known, Op1Known)) + if (std::optional Res = ICmpInst::compare(Op0Known, Op1Known, Pred)) return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res)); // Given the known and unknown bits, compute a range that the LHS could be