-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Use KnownBits predicate helpers #115874
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Nikita Popov (nikic) ChangesInside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of KnownBits::eq() etc. This gives better results for the equality predicates. I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions. Full diff: https://github.com/llvm/llvm-project/pull/115874.diff 5 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 5a8814dfd6b3d3..975abf027f6c54 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6544,6 +6544,35 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
return false;
}
+static std::optional<bool> compareKnownBits(ICmpInst::Predicate Pred,
+ const KnownBits &Op0,
+ const KnownBits &Op1) {
+ switch (Pred) {
+ case ICmpInst::ICMP_EQ:
+ return KnownBits::eq(Op0, Op1);
+ case ICmpInst::ICMP_NE:
+ return KnownBits::ne(Op0, Op1);
+ case ICmpInst::ICMP_ULT:
+ return KnownBits::ult(Op0, Op1);
+ case ICmpInst::ICMP_ULE:
+ return KnownBits::ule(Op0, Op1);
+ case ICmpInst::ICMP_UGT:
+ return KnownBits::ugt(Op0, Op1);
+ case ICmpInst::ICMP_UGE:
+ return KnownBits::uge(Op0, Op1);
+ case ICmpInst::ICMP_SLT:
+ return KnownBits::slt(Op0, Op1);
+ case ICmpInst::ICMP_SLE:
+ return KnownBits::sle(Op0, Op1);
+ case ICmpInst::ICMP_SGT:
+ return KnownBits::sgt(Op0, Op1);
+ case ICmpInst::ICMP_SGE:
+ return KnownBits::sge(Op0, Op1);
+ default:
+ llvm_unreachable("Unknown predicate");
+ }
+}
+
/// Try to fold the comparison based on range information we can get by checking
/// whether bits are known to be zero or one in the inputs.
Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
@@ -6576,6 +6605,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return &I;
}
+ if (!isa<Constant>(Op0) && Op0Known.isConstant())
+ return new ICmpInst(
+ Pred, ConstantExpr::getIntegerValue(Ty, Op0Known.getConstant()), Op1);
+ if (!isa<Constant>(Op1) && Op1Known.isConstant())
+ return new ICmpInst(
+ Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant()));
+
+ if (std::optional<bool> Res = compareKnownBits(Pred, Op0Known, Op1Known))
+ return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res));
+
// Given the known and unknown bits, compute a range that the LHS could be
// in. Compute the Min, Max and RHS values based on the known bits. For the
// EQ and NE we use unsigned values.
@@ -6593,14 +6632,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
Op1Max = Op1Known.getMaxValue();
}
- // If Min and Max are known to be the same, then SimplifyDemandedBits figured
- // out that the LHS or RHS is a constant. Constant fold this now, so that
- // code below can assume that Min != Max.
- if (!isa<Constant>(Op0) && Op0Min == Op0Max)
- return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
- if (!isa<Constant>(Op1) && Op1Min == Op1Max)
- return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
-
// Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
// min/max canonical compare with some other compare. That could lead to
// conflict with select canonicalization and infinite looping.
@@ -6682,13 +6713,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
// simplify this comparison. For example, (x&4) < 8 is always true.
switch (Pred) {
default:
- llvm_unreachable("Unknown icmp opcode!");
+ break;
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE: {
- if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
- return replaceInstUsesWith(
- I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
-
// If all bits are known zero except for one, then we know at most one bit
// is set. If the comparison is against zero, then this is a check to see if
// *that* bit is set.
@@ -6728,67 +6755,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
ConstantInt::getNullValue(Op1->getType()));
break;
}
- case ICmpInst::ICMP_ULT: {
- if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
- case ICmpInst::ICMP_UGT: {
- if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
- case ICmpInst::ICMP_SLT: {
- if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
- case ICmpInst::ICMP_SGT: {
- if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
case ICmpInst::ICMP_SGE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
- if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_SLE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
- if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_UGE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
- if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_ULE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
- if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
diff --git a/llvm/test/Transforms/InstCombine/icmp-gep.ll b/llvm/test/Transforms/InstCombine/icmp-gep.ll
index 887cf1162319bc..776716fe908733 100644
--- a/llvm/test/Transforms/InstCombine/icmp-gep.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-gep.ll
@@ -583,9 +583,7 @@ define i1 @gep_nusw(ptr %p, i64 %a, i64 %b, i64 %c, i64 %d) {
define i1 @pointer_icmp_aligned_with_offset(ptr align 8 %a, ptr align 8 %a2) {
; CHECK-LABEL: @pointer_icmp_aligned_with_offset(
-; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 4
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[GEP]], [[A2:%.*]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 false
;
%gep = getelementptr i8, ptr %a, i64 4
%cmp = icmp eq ptr %gep, %a2
diff --git a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
index 997758af62a543..8baf6a70fdd5d1 100644
--- a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
@@ -570,12 +570,12 @@ define i64 @test30(i32 %A, i32 %B) {
@PR22087 = external global i32
define i32 @test31(i32 %V) {
; CHECK-LABEL: @test31(
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32
; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
; CHECK-NEXT: ret i32 [[MUL1]]
;
- %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+ %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
%ext = zext i1 %cmp to i32
%shl = shl i32 1, %ext
%mul = mul i32 %V, %shl
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index e38ab1b9622b2c..340828a8d3f9dd 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -1152,12 +1152,12 @@ define i64 @test30(i32 %A, i32 %B) {
@PR22087 = external global i32
define i32 @test31(i32 %V) {
; CHECK-LABEL: @test31(
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32
; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
; CHECK-NEXT: ret i32 [[MUL1]]
;
- %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+ %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
%ext = zext i1 %cmp to i32
%shl = shl i32 1, %ext
%mul = mul i32 %V, %shl
diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
index 070a3b03302124..e95955da1b8728 100644
--- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
+++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
@@ -669,7 +669,7 @@ define <2 x i1> @n38_overshift(<2 x i32> %x, <2 x i32> %y) {
}
; As usual, don't crash given constantexpr's :/
-@f.a = internal global i16 0
+@f.a = internal global i16 0, align 1
define i1 @constantexpr() {
; CHECK-LABEL: @constantexpr(
; CHECK-NEXT: entry:
|
This may have a bit of compile-time impact (it's close to the noise threshold): https://llvm-compile-time-tracker.com/compare.php?from=36f21eedcfd06ff97ead9625adbf6d8153edd233&to=a371228364bfdec8864e4a3af655d45c8700a87f&stat=instructions:u If preferred I could only use |
The reason this only makes a different for pointers is that simplifyICmpInst calls isKnownNonEqual, which does a KnownBits based check for integers, but not pointers. So an alternative fix would be to adjust isKnownNonEqual() to handle this for pointers as well. |
The two don't seem mutually exclusive. This patch has value IMO even if it's just dropping some bespoke logic for better tested/implemented apis |
Generally agree. However, I just tested the isKnownNonEqual change, and that one has more significant compile-time impact: http://llvm-compile-time-tracker.com/compare.php?from=6d8d9fc8d279623cca94b2b875a92517ed308f18&to=49f63dc5015abe32735e99c39cf86bab32a7a19a&stat=instructions:u It's adding 0.2% to clang thin link. So I'm inclined to go with this patch only. |
Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of KnownBits::eq() etc. This gives better results for the equality predicates. I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions.
bd61adb
to
136425b
Compare
Fair enough, is there no other independent value to IsKnownNonEqual for ptrs? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of ICmpInst::compare() working on KnownBits. This gives better results for the equality predicates. In practice, the improvement is only for pointers, because isKnownNonEqual() handles the non-pointer case. I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions.
Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of ICmpInst::compare() working on KnownBits. This gives better results for the equality predicates. In practice, the improvement is only for pointers, because isKnownNonEqual() handles the non-pointer case.
I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions.