Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Use KnownBits predicate helpers #115874

Merged
merged 2 commits into from
Nov 14, 2024

Conversation

nikic
Copy link
Contributor

@nikic nikic commented Nov 12, 2024

Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of ICmpInst::compare() working on KnownBits. This gives better results for the equality predicates. In practice, the improvement is only for pointers, because isKnownNonEqual() handles the non-pointer case.

I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions.

@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

Changes

Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of KnownBits::eq() etc. This gives better results for the equality predicates.

I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions.


Full diff: https://github.com/llvm/llvm-project/pull/115874.diff

5 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+40-61)
  • (modified) llvm/test/Transforms/InstCombine/icmp-gep.ll (+1-3)
  • (modified) llvm/test/Transforms/InstCombine/mul-inseltpoison.ll (+2-2)
  • (modified) llvm/test/Transforms/InstCombine/mul.ll (+2-2)
  • (modified) llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll (+1-1)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 5a8814dfd6b3d3..975abf027f6c54 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6544,6 +6544,35 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
   return false;
 }
 
+static std::optional<bool> compareKnownBits(ICmpInst::Predicate Pred,
+                                            const KnownBits &Op0,
+                                            const KnownBits &Op1) {
+  switch (Pred) {
+  case ICmpInst::ICMP_EQ:
+    return KnownBits::eq(Op0, Op1);
+  case ICmpInst::ICMP_NE:
+    return KnownBits::ne(Op0, Op1);
+  case ICmpInst::ICMP_ULT:
+    return KnownBits::ult(Op0, Op1);
+  case ICmpInst::ICMP_ULE:
+    return KnownBits::ule(Op0, Op1);
+  case ICmpInst::ICMP_UGT:
+    return KnownBits::ugt(Op0, Op1);
+  case ICmpInst::ICMP_UGE:
+    return KnownBits::uge(Op0, Op1);
+  case ICmpInst::ICMP_SLT:
+    return KnownBits::slt(Op0, Op1);
+  case ICmpInst::ICMP_SLE:
+    return KnownBits::sle(Op0, Op1);
+  case ICmpInst::ICMP_SGT:
+    return KnownBits::sgt(Op0, Op1);
+  case ICmpInst::ICMP_SGE:
+    return KnownBits::sge(Op0, Op1);
+  default:
+    llvm_unreachable("Unknown predicate");
+  }
+}
+
 /// Try to fold the comparison based on range information we can get by checking
 /// whether bits are known to be zero or one in the inputs.
 Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
@@ -6576,6 +6605,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
       return &I;
   }
 
+  if (!isa<Constant>(Op0) && Op0Known.isConstant())
+    return new ICmpInst(
+        Pred, ConstantExpr::getIntegerValue(Ty, Op0Known.getConstant()), Op1);
+  if (!isa<Constant>(Op1) && Op1Known.isConstant())
+    return new ICmpInst(
+        Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant()));
+
+  if (std::optional<bool> Res = compareKnownBits(Pred, Op0Known, Op1Known))
+    return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res));
+
   // Given the known and unknown bits, compute a range that the LHS could be
   // in.  Compute the Min, Max and RHS values based on the known bits. For the
   // EQ and NE we use unsigned values.
@@ -6593,14 +6632,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
     Op1Max = Op1Known.getMaxValue();
   }
 
-  // If Min and Max are known to be the same, then SimplifyDemandedBits figured
-  // out that the LHS or RHS is a constant. Constant fold this now, so that
-  // code below can assume that Min != Max.
-  if (!isa<Constant>(Op0) && Op0Min == Op0Max)
-    return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
-  if (!isa<Constant>(Op1) && Op1Min == Op1Max)
-    return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
-
   // Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
   // min/max canonical compare with some other compare. That could lead to
   // conflict with select canonicalization and infinite looping.
@@ -6682,13 +6713,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
   // simplify this comparison.  For example, (x&4) < 8 is always true.
   switch (Pred) {
   default:
-    llvm_unreachable("Unknown icmp opcode!");
+    break;
   case ICmpInst::ICMP_EQ:
   case ICmpInst::ICMP_NE: {
-    if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
-      return replaceInstUsesWith(
-          I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
-
     // If all bits are known zero except for one, then we know at most one bit
     // is set. If the comparison is against zero, then this is a check to see if
     // *that* bit is set.
@@ -6728,67 +6755,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
                           ConstantInt::getNullValue(Op1->getType()));
     break;
   }
-  case ICmpInst::ICMP_ULT: {
-    if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
-  case ICmpInst::ICMP_UGT: {
-    if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
-  case ICmpInst::ICMP_SLT: {
-    if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
-  case ICmpInst::ICMP_SGT: {
-    if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
   case ICmpInst::ICMP_SGE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
-    if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_SLE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
-    if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_UGE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
-    if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_ULE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
-    if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
diff --git a/llvm/test/Transforms/InstCombine/icmp-gep.ll b/llvm/test/Transforms/InstCombine/icmp-gep.ll
index 887cf1162319bc..776716fe908733 100644
--- a/llvm/test/Transforms/InstCombine/icmp-gep.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-gep.ll
@@ -583,9 +583,7 @@ define i1 @gep_nusw(ptr %p, i64 %a, i64 %b, i64 %c, i64 %d) {
 
 define i1 @pointer_icmp_aligned_with_offset(ptr align 8 %a, ptr align 8 %a2) {
 ; CHECK-LABEL: @pointer_icmp_aligned_with_offset(
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 4
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq ptr [[GEP]], [[A2:%.*]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
   %gep = getelementptr i8, ptr %a, i64 4
   %cmp = icmp eq ptr %gep, %a2
diff --git a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
index 997758af62a543..8baf6a70fdd5d1 100644
--- a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
@@ -570,12 +570,12 @@ define i64 @test30(i32 %A, i32 %B) {
 @PR22087 = external global i32
 define i32 @test31(i32 %V) {
 ; CHECK-LABEL: @test31(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
 ; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT:    [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
 ; CHECK-NEXT:    ret i32 [[MUL1]]
 ;
-  %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+  %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
   %ext = zext i1 %cmp to i32
   %shl = shl i32 1, %ext
   %mul = mul i32 %V, %shl
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index e38ab1b9622b2c..340828a8d3f9dd 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -1152,12 +1152,12 @@ define i64 @test30(i32 %A, i32 %B) {
 @PR22087 = external global i32
 define i32 @test31(i32 %V) {
 ; CHECK-LABEL: @test31(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
 ; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT:    [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
 ; CHECK-NEXT:    ret i32 [[MUL1]]
 ;
-  %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+  %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
   %ext = zext i1 %cmp to i32
   %shl = shl i32 1, %ext
   %mul = mul i32 %V, %shl
diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
index 070a3b03302124..e95955da1b8728 100644
--- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
+++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
@@ -669,7 +669,7 @@ define <2 x i1> @n38_overshift(<2 x i32> %x, <2 x i32> %y) {
 }
 
 ; As usual, don't crash given constantexpr's :/
-@f.a = internal global i16 0
+@f.a = internal global i16 0, align 1
 define i1 @constantexpr() {
 ; CHECK-LABEL: @constantexpr(
 ; CHECK-NEXT:  entry:

@nikic
Copy link
Contributor Author

nikic commented Nov 12, 2024

This may have a bit of compile-time impact (it's close to the noise threshold): https://llvm-compile-time-tracker.com/compare.php?from=36f21eedcfd06ff97ead9625adbf6d8153edd233&to=a371228364bfdec8864e4a3af655d45c8700a87f&stat=instructions:u

If preferred I could only use KnownBits::eq and leave the rest of the code alone. I figured it would be nice to consolidate the handling though.

@nikic
Copy link
Contributor Author

nikic commented Nov 12, 2024

The reason this only makes a different for pointers is that simplifyICmpInst calls isKnownNonEqual, which does a KnownBits based check for integers, but not pointers.

So an alternative fix would be to adjust isKnownNonEqual() to handle this for pointers as well.

@goldsteinn
Copy link
Contributor

The reason this only makes a different for pointers is that simplifyICmpInst calls isKnownNonEqual, which does a KnownBits based check for integers, but not pointers.

So an alternative fix would be to adjust isKnownNonEqual() to handle this for pointers as well.

The two don't seem mutually exclusive. This patch has value IMO even if it's just dropping some bespoke logic for better tested/implemented apis

@nikic
Copy link
Contributor Author

nikic commented Nov 12, 2024

The reason this only makes a different for pointers is that simplifyICmpInst calls isKnownNonEqual, which does a KnownBits based check for integers, but not pointers.
So an alternative fix would be to adjust isKnownNonEqual() to handle this for pointers as well.

The two don't seem mutually exclusive. This patch has value IMO even if it's just dropping some bespoke logic for better tested/implemented apis

Generally agree. However, I just tested the isKnownNonEqual change, and that one has more significant compile-time impact: http://llvm-compile-time-tracker.com/compare.php?from=6d8d9fc8d279623cca94b2b875a92517ed308f18&to=49f63dc5015abe32735e99c39cf86bab32a7a19a&stat=instructions:u It's adding 0.2% to clang thin link.

So I'm inclined to go with this patch only.

Inside foldICmpUsingKnownBits(), instead of rolling our own
logic based on min/max values, make use of KnownBits::eq() etc.
This gives better results for the equality predicates.

I've adjusted some tests to prevent the new fold from triggering,
to retain their original intent of testing constant expressions.
@nikic nikic force-pushed the instcombine-icmp-known-bits branch from bd61adb to 136425b Compare November 12, 2024 16:57
@goldsteinn
Copy link
Contributor

The reason this only makes a different for pointers is that simplifyICmpInst calls isKnownNonEqual, which does a KnownBits based check for integers, but not pointers.
So an alternative fix would be to adjust isKnownNonEqual() to handle this for pointers as well.

The two don't seem mutually exclusive. This patch has value IMO even if it's just dropping some bespoke logic for better tested/implemented apis

Generally agree. However, I just tested the isKnownNonEqual change, and that one has more significant compile-time impact: http://llvm-compile-time-tracker.com/compare.php?from=6d8d9fc8d279623cca94b2b875a92517ed308f18&to=49f63dc5015abe32735e99c39cf86bab32a7a19a&stat=instructions:u It's adding 0.2% to clang thin link.

So I'm inclined to go with this patch only.

Fair enough, is there no other independent value to IsKnownNonEqual for ptrs?

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

@nikic nikic merged commit 78f7ca0 into llvm:main Nov 14, 2024
6 of 8 checks passed
@nikic nikic deleted the instcombine-icmp-known-bits branch November 14, 2024 09:13
akshayrdeodhar pushed a commit to akshayrdeodhar/llvm-project that referenced this pull request Nov 18, 2024
Inside foldICmpUsingKnownBits(), instead of rolling our own logic based
on min/max values, make use of ICmpInst::compare() working on KnownBits.
This gives better results for the equality predicates. In practice, the
improvement is only for pointers, because isKnownNonEqual() handles the
non-pointer case.

I've adjusted some tests to prevent the new fold from triggering, to
retain their original intent of testing constant expressions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants