Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Simplify select using KnownBits of condition #95923

Merged
merged 4 commits into from
Jul 1, 2024

Conversation

nikic
Copy link
Contributor

@nikic nikic commented Jun 18, 2024

Simplify the arms of a select based on the KnownBits implied by its condition. For now this only handles the case where the select arm folds to a constant, but this can be generalized to handle other patterns by using SimplifyDemandedBits instead (in that case we would also have to limit to non-undef conditions).

This is implemented by adding a new member to SimplifyQuery that can be used to inject an additional condition. The affected values are pre-computed and we don't call computeKnownBits() if the select arms don't contain affected values. This reduces the cost in some pathological cases.

@llvmbot
Copy link
Member

llvmbot commented Jun 18, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)

Changes

Simplify the arms of a select based on the KnownBits implied by its condition. For now this only handles the case where the select arm folds to a constant, but this can be generalized to handle other patterns by using SimplifyDemandedBits instead (in that case we would also have to limit to non-undef conditions).

This has some compile-time overhead: http://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&to=e05bb77219667121d5042a4d017b15488d55d22b&stat=instructions:u Unfortunately, the majority of the overhead seems to come simply from adding an extra member to SimplifyQuery and not the actual analysis...


Patch is 26.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95923.diff

7 Files Affected:

  • (modified) llvm/include/llvm/Analysis/SimplifyQuery.h (+17)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+25)
  • (modified) llvm/test/Transforms/InstCombine/select-binop-cmp.ll (+1-4)
  • (modified) llvm/test/Transforms/InstCombine/select-of-bittest.ll (+3-5)
  • (modified) llvm/test/Transforms/InstCombine/select.ll (+6-16)
  • (modified) llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll (+79-85)
diff --git a/llvm/include/llvm/Analysis/SimplifyQuery.h b/llvm/include/llvm/Analysis/SimplifyQuery.h
index 25b8f9b5eaf10..a560744f01222 100644
--- a/llvm/include/llvm/Analysis/SimplifyQuery.h
+++ b/llvm/include/llvm/Analysis/SimplifyQuery.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_ANALYSIS_SIMPLIFYQUERY_H
 #define LLVM_ANALYSIS_SIMPLIFYQUERY_H
 
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/IR/Operator.h"
 
 namespace llvm {
@@ -57,6 +58,15 @@ struct InstrInfoQuery {
   }
 };
 
+/// Evaluate query assuming this condition holds.
+struct CondContext {
+  Value *Cond;
+  bool Invert = false;
+  SmallPtrSet<Value *, 4> AffectedValues;
+
+  CondContext(Value *Cond) : Cond(Cond) {}
+};
+
 struct SimplifyQuery {
   const DataLayout &DL;
   const TargetLibraryInfo *TLI = nullptr;
@@ -64,6 +74,7 @@ struct SimplifyQuery {
   AssumptionCache *AC = nullptr;
   const Instruction *CxtI = nullptr;
   const DomConditionCache *DC = nullptr;
+  const CondContext *CC = nullptr;
 
   // Wrapper to query additional information for instructions like metadata or
   // keywords like nsw, which provides conservative results if those cannot
@@ -113,6 +124,12 @@ struct SimplifyQuery {
     Copy.DC = nullptr;
     return Copy;
   }
+
+  SimplifyQuery getWithCondContext(const CondContext &CC) const {
+    SimplifyQuery Copy(*this);
+    Copy.CC = &CC;
+    return Copy;
+  }
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8126d2a1acc27..a4a6782bce6dc 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -771,6 +771,10 @@ static void computeKnownBitsFromCond(const Value *V, Value *Cond,
 
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                        unsigned Depth, const SimplifyQuery &Q) {
+  // Handle injected condition.
+  if (Q.CC && Q.CC->AffectedValues.contains(V))
+    computeKnownBitsFromCond(V, Q.CC->Cond, Known, Depth, Q, Q.CC->Invert);
+
   if (!Q.CxtI)
     return;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 960c5a29569e3..626fca92c4ff9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4018,5 +4018,30 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (CondVal->getType() == SI.getType() && isKnownInversion(FalseVal, TrueVal))
     return BinaryOperator::CreateXor(CondVal, FalseVal);
 
+  if (SelType->isIntOrIntVectorTy()) {
+    // Try to simplify select arms based on KnownBits implied by the condition.
+    CondContext CC(CondVal);
+    findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
+      CC.AffectedValues.insert(V);
+    });
+    SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
+    if (!CC.AffectedValues.empty()) {
+      if (!isa<Constant>(TrueVal)) {
+        KnownBits Known = llvm::computeKnownBits(TrueVal, /*Depth=*/0, Q);
+        if (Known.isConstant())
+          return replaceOperand(SI, 1,
+                                ConstantInt::get(SelType, Known.getConstant()));
+      }
+
+      CC.Invert = true;
+      if (!isa<Constant>(FalseVal)) {
+        KnownBits Known = llvm::computeKnownBits(FalseVal, /*Depth=*/0, Q);
+        if (Known.isConstant())
+          return replaceOperand(SI, 2,
+                                ConstantInt::get(SelType, Known.getConstant()));
+      }
+    }
+  }
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
index 1fa0c09a9e987..9ee2bc57c3b87 100644
--- a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
@@ -571,10 +571,7 @@ define <2 x i8> @select_xor_icmp_vec_bad(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z)
 
 define <2 x i32> @vec_select_no_equivalence(<2 x i32> %x) {
 ; CHECK-LABEL: @vec_select_no_equivalence(
-; CHECK-NEXT:    [[X10:%.*]] = shufflevector <2 x i32> [[X:%.*]], <2 x i32> poison, <2 x i32> <i32 1, i32 0>
-; CHECK-NEXT:    [[COND:%.*]] = icmp eq <2 x i32> [[X]], zeroinitializer
-; CHECK-NEXT:    [[S:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[X10]], <2 x i32> [[X]]
-; CHECK-NEXT:    ret <2 x i32> [[S]]
+; CHECK-NEXT:    ret <2 x i32> [[X:%.*]]
 ;
   %x10 = shufflevector <2 x i32> %x, <2 x i32> undef, <2 x i32> <i32 1, i32 0>
   %cond = icmp eq <2 x i32> %x, zeroinitializer
diff --git a/llvm/test/Transforms/InstCombine/select-of-bittest.ll b/llvm/test/Transforms/InstCombine/select-of-bittest.ll
index e3eb76de459e2..50d3c87f199c3 100644
--- a/llvm/test/Transforms/InstCombine/select-of-bittest.ll
+++ b/llvm/test/Transforms/InstCombine/select-of-bittest.ll
@@ -588,11 +588,9 @@ define i32 @n4(i32 %arg) {
 
 define i32 @n5(i32 %arg) {
 ; CHECK-LABEL: @n5(
-; CHECK-NEXT:    [[T:%.*]] = and i32 [[ARG:%.*]], 2
-; CHECK-NEXT:    [[T1:%.*]] = icmp eq i32 [[T]], 0
-; CHECK-NEXT:    [[T2:%.*]] = and i32 [[ARG]], 2
-; CHECK-NEXT:    [[T3:%.*]] = select i1 [[T1]], i32 [[T2]], i32 1
-; CHECK-NEXT:    ret i32 [[T3]]
+; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[ARG:%.*]], 1
+; CHECK-NEXT:    [[T_LOBIT:%.*]] = and i32 [[T]], 1
+; CHECK-NEXT:    ret i32 [[T_LOBIT]]
 ;
   %t = and i32 %arg, 2
   %t1 = icmp eq i32 %t, 0
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index b37e9175b26a5..192d7a9629466 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3807,9 +3807,8 @@ define i32 @src_and_eq_neg1_or_xor(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], -1
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[OR]], i32 [[XOR]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[XOR]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -3827,9 +3826,8 @@ define i32 @src_and_eq_neg1_xor_or(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], -1
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[OR]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[OR]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -3942,9 +3940,8 @@ define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[AND]], i32 [[XOR]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -3962,9 +3959,8 @@ define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[AND]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -4474,10 +4470,7 @@ define i32 @src_no_trans_select_or_eq0_or_xor(i32 %x, i32 %y) {
 define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_no_trans_select_or_eq0_and_or(
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[OR0]], i32 [[AND]], i32 [[OR]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    ret i32 [[OR]]
 ;
   %or = or i32 %x, %y
   %or0 = icmp eq i32 %or, 0
@@ -4489,10 +4482,7 @@ define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
 define i32 @src_no_trans_select_or_eq0_xor_or(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_no_trans_select_or_eq0_xor_or(
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[OR0]], i32 [[XOR]], i32 [[OR]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    ret i32 [[OR]]
 ;
   %or = or i32 %x, %y
   %or0 = icmp eq i32 %or, 0
diff --git a/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll b/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll
index 86ca1222aa4ea..fb4545cb715bb 100644
--- a/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll
+++ b/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll
@@ -20,39 +20,37 @@ define void @interleaved_with_cond_store_0(ptr %p, i64 %x, i64 %n) {
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp slt i64 [[N:%.*]], 3
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = and i64 [[N]], 1
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[TMP0]], i64 2, i64 [[N_MOD_VF]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub nsw i64 [[N]], [[TMP1]]
+; CHECK-NEXT:    [[DOTNEG:%.*]] = or i64 [[N]], -2
+; CHECK-NEXT:    [[N_VEC:%.*]] = add nsw i64 [[DOTNEG]], [[N]]
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i64> [[BROADCAST_SPLATINSERT]], <2 x i64> poison, <2 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_STORE_CONTINUE2:%.*]] ]
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 1
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP2]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 1
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> <i32 0, i32 2>
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x i1> [[TMP3]], i64 0
-; CHECK-NEXT:    br i1 [[TMP4]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i1> [[TMP1]], i64 0
+; CHECK-NEXT:    br i1 [[TMP2]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
 ; CHECK:       pred.store.if:
-; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
-; CHECK-NEXT:    store i64 [[TMP6]], ptr [[TMP5]], align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
+; CHECK-NEXT:    store i64 [[TMP4]], ptr [[TMP3]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE]]
 ; CHECK:       pred.store.continue:
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i1> [[TMP3]], i64 1
-; CHECK-NEXT:    br i1 [[TMP7]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[TMP1]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.if1:
-; CHECK-NEXT:    [[TMP8:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP8]], i32 1
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
-; CHECK-NEXT:    store i64 [[TMP10]], ptr [[TMP9]], align 8
+; CHECK-NEXT:    [[TMP6:%.*]] = or disjoint i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP6]], i32 1
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
+; CHECK-NEXT:    store i64 [[TMP8]], ptr [[TMP7]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.continue2:
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    br label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
@@ -61,11 +59,11 @@ define void @interleaved_with_cond_store_0(ptr %p, i64 %x, i64 %n) {
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_NEXT:%.*]], [[IF_MERGE:%.*]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[P_1:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[I]], i32 1
-; CHECK-NEXT:    [[TMP12:%.*]] = load i64, ptr [[P_1]], align 8
-; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[TMP12]], [[X]]
-; CHECK-NEXT:    br i1 [[TMP13]], label [[IF_THEN:%.*]], label [[IF_MERGE]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load i64, ptr [[P_1]], align 8
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp eq i64 [[TMP10]], [[X]]
+; CHECK-NEXT:    br i1 [[TMP11]], label [[IF_THEN:%.*]], label [[IF_MERGE]]
 ; CHECK:       if.then:
-; CHECK-NEXT:    store i64 [[TMP12]], ptr [[P_1]], align 8
+; CHECK-NEXT:    store i64 [[TMP10]], ptr [[P_1]], align 8
 ; CHECK-NEXT:    br label [[IF_MERGE]]
 ; CHECK:       if.merge:
 ; CHECK-NEXT:    [[I_NEXT]] = add nuw nsw i64 [[I]], 1
@@ -114,46 +112,44 @@ define void @interleaved_with_cond_store_1(ptr %p, i64 %x, i64 %n) {
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp slt i64 [[N:%.*]], 3
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = and i64 [[N]], 1
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[TMP0]], i64 2, i64 [[N_MOD_VF]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub nsw i64 [[N]], [[TMP1]]
+; CHECK-NEXT:    [[DOTNEG:%.*]] = or i64 [[N]], -2
+; CHECK-NEXT:    [[N_VEC:%.*]] = add nsw i64 [[DOTNEG]], [[N]]
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i64> [[BROADCAST_SPLATINSERT]], <2 x i64> poison, <2 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_STORE_CONTINUE2:%.*]] ]
-; CHECK-NEXT:    [[TMP2:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 0
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
-; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP2]], i32 1
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP4]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = or disjoint i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 0
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP0]], i32 1
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP2]], align 8
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> <i32 0, i32 2>
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i1> [[TMP6]], i64 0
-; CHECK-NEXT:    br i1 [[TMP7]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    br i1 [[TMP5]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
 ; CHECK:       pred.store.if:
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 0
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
-; CHECK-NEXT:    store i64 [[TMP9]], ptr [[TMP8]], align 8
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
+; CHECK-NEXT:    store i64 [[TMP7]], ptr [[TMP6]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE]]
 ; CHECK:       pred.store.continue:
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x i1> [[TMP6]], i64 1
-; CHECK-NEXT:    br i1 [[TMP10]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x i1> [[TMP4]], i64 1
+; CHECK-NEXT:    br i1 [[TMP8]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.if1:
-; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP2]], i32 0
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
-; CHECK-NEXT:    store i64 [[TMP12]], ptr [[TMP11]], align 8
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP0]], i32 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
+; CHECK-NEXT:    store i64 [[TMP10]], ptr [[TMP9]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.continue2:
-; CHECK-NEXT:    [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP3]], align 8
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 0
-; CHECK-NEXT:    store i64 [[TMP13]], ptr [[TMP4]], align 8
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 2
-; CHECK-NEXT:    store i64 [[TMP14]], ptr [[TMP5]], align 8
+; CHECK-NEXT:    [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 0
+; CHECK-NEXT:    store i64 [[TMP11]], ptr [[TMP2]], align 8
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 2
+; CHECK-NEXT:    store i64 [[TMP12]], ptr [[TMP3]], align 8
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    br label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
@@ -163,15 +159,15 @@ define void @interleaved_with_cond_store_1(ptr %p, i64 %x, i64 %n) {
 ; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_NEXT:%.*]], [[IF_MERGE:%.*]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[P_0:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[I]], i32 0
 ; CHECK-NEXT:    [[P_1:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[I]], i32 1
-; CHECK-NEXT:    [[TMP16:%.*]] = load i64, ptr [[P_1]], align 8
-; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i64 [[TMP16]], [[X]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[IF_THEN:%.*]], label [[IF_MERGE]]
+; CHECK-NEXT:    [[TMP14:%.*]] = load i64, ptr [[P_1]], align 8
+; CHECK-NEXT:    [[TMP15:%.*]] = ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 18, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

Changes

Simplify the arms of a select based on the KnownBits implied by its condition. For now this only handles the case where the select arm folds to a constant, but this can be generalized to handle other patterns by using SimplifyDemandedBits instead (in that case we would also have to limit to non-undef conditions).

This has some compile-time overhead: http://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&amp;to=e05bb77219667121d5042a4d017b15488d55d22b&amp;stat=instructions:u Unfortunately, the majority of the overhead seems to come simply from adding an extra member to SimplifyQuery and not the actual analysis...


Patch is 26.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95923.diff

7 Files Affected:

  • (modified) llvm/include/llvm/Analysis/SimplifyQuery.h (+17)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+25)
  • (modified) llvm/test/Transforms/InstCombine/select-binop-cmp.ll (+1-4)
  • (modified) llvm/test/Transforms/InstCombine/select-of-bittest.ll (+3-5)
  • (modified) llvm/test/Transforms/InstCombine/select.ll (+6-16)
  • (modified) llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll (+79-85)
diff --git a/llvm/include/llvm/Analysis/SimplifyQuery.h b/llvm/include/llvm/Analysis/SimplifyQuery.h
index 25b8f9b5eaf10..a560744f01222 100644
--- a/llvm/include/llvm/Analysis/SimplifyQuery.h
+++ b/llvm/include/llvm/Analysis/SimplifyQuery.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_ANALYSIS_SIMPLIFYQUERY_H
 #define LLVM_ANALYSIS_SIMPLIFYQUERY_H
 
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/IR/Operator.h"
 
 namespace llvm {
@@ -57,6 +58,15 @@ struct InstrInfoQuery {
   }
 };
 
+/// Evaluate query assuming this condition holds.
+struct CondContext {
+  Value *Cond;
+  bool Invert = false;
+  SmallPtrSet<Value *, 4> AffectedValues;
+
+  CondContext(Value *Cond) : Cond(Cond) {}
+};
+
 struct SimplifyQuery {
   const DataLayout &DL;
   const TargetLibraryInfo *TLI = nullptr;
@@ -64,6 +74,7 @@ struct SimplifyQuery {
   AssumptionCache *AC = nullptr;
   const Instruction *CxtI = nullptr;
   const DomConditionCache *DC = nullptr;
+  const CondContext *CC = nullptr;
 
   // Wrapper to query additional information for instructions like metadata or
   // keywords like nsw, which provides conservative results if those cannot
@@ -113,6 +124,12 @@ struct SimplifyQuery {
     Copy.DC = nullptr;
     return Copy;
   }
+
+  SimplifyQuery getWithCondContext(const CondContext &CC) const {
+    SimplifyQuery Copy(*this);
+    Copy.CC = &CC;
+    return Copy;
+  }
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8126d2a1acc27..a4a6782bce6dc 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -771,6 +771,10 @@ static void computeKnownBitsFromCond(const Value *V, Value *Cond,
 
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                        unsigned Depth, const SimplifyQuery &Q) {
+  // Handle injected condition.
+  if (Q.CC && Q.CC->AffectedValues.contains(V))
+    computeKnownBitsFromCond(V, Q.CC->Cond, Known, Depth, Q, Q.CC->Invert);
+
   if (!Q.CxtI)
     return;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 960c5a29569e3..626fca92c4ff9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4018,5 +4018,30 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (CondVal->getType() == SI.getType() && isKnownInversion(FalseVal, TrueVal))
     return BinaryOperator::CreateXor(CondVal, FalseVal);
 
+  if (SelType->isIntOrIntVectorTy()) {
+    // Try to simplify select arms based on KnownBits implied by the condition.
+    CondContext CC(CondVal);
+    findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
+      CC.AffectedValues.insert(V);
+    });
+    SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
+    if (!CC.AffectedValues.empty()) {
+      if (!isa<Constant>(TrueVal)) {
+        KnownBits Known = llvm::computeKnownBits(TrueVal, /*Depth=*/0, Q);
+        if (Known.isConstant())
+          return replaceOperand(SI, 1,
+                                ConstantInt::get(SelType, Known.getConstant()));
+      }
+
+      CC.Invert = true;
+      if (!isa<Constant>(FalseVal)) {
+        KnownBits Known = llvm::computeKnownBits(FalseVal, /*Depth=*/0, Q);
+        if (Known.isConstant())
+          return replaceOperand(SI, 2,
+                                ConstantInt::get(SelType, Known.getConstant()));
+      }
+    }
+  }
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
index 1fa0c09a9e987..9ee2bc57c3b87 100644
--- a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
@@ -571,10 +571,7 @@ define <2 x i8> @select_xor_icmp_vec_bad(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z)
 
 define <2 x i32> @vec_select_no_equivalence(<2 x i32> %x) {
 ; CHECK-LABEL: @vec_select_no_equivalence(
-; CHECK-NEXT:    [[X10:%.*]] = shufflevector <2 x i32> [[X:%.*]], <2 x i32> poison, <2 x i32> <i32 1, i32 0>
-; CHECK-NEXT:    [[COND:%.*]] = icmp eq <2 x i32> [[X]], zeroinitializer
-; CHECK-NEXT:    [[S:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[X10]], <2 x i32> [[X]]
-; CHECK-NEXT:    ret <2 x i32> [[S]]
+; CHECK-NEXT:    ret <2 x i32> [[X:%.*]]
 ;
   %x10 = shufflevector <2 x i32> %x, <2 x i32> undef, <2 x i32> <i32 1, i32 0>
   %cond = icmp eq <2 x i32> %x, zeroinitializer
diff --git a/llvm/test/Transforms/InstCombine/select-of-bittest.ll b/llvm/test/Transforms/InstCombine/select-of-bittest.ll
index e3eb76de459e2..50d3c87f199c3 100644
--- a/llvm/test/Transforms/InstCombine/select-of-bittest.ll
+++ b/llvm/test/Transforms/InstCombine/select-of-bittest.ll
@@ -588,11 +588,9 @@ define i32 @n4(i32 %arg) {
 
 define i32 @n5(i32 %arg) {
 ; CHECK-LABEL: @n5(
-; CHECK-NEXT:    [[T:%.*]] = and i32 [[ARG:%.*]], 2
-; CHECK-NEXT:    [[T1:%.*]] = icmp eq i32 [[T]], 0
-; CHECK-NEXT:    [[T2:%.*]] = and i32 [[ARG]], 2
-; CHECK-NEXT:    [[T3:%.*]] = select i1 [[T1]], i32 [[T2]], i32 1
-; CHECK-NEXT:    ret i32 [[T3]]
+; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[ARG:%.*]], 1
+; CHECK-NEXT:    [[T_LOBIT:%.*]] = and i32 [[T]], 1
+; CHECK-NEXT:    ret i32 [[T_LOBIT]]
 ;
   %t = and i32 %arg, 2
   %t1 = icmp eq i32 %t, 0
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index b37e9175b26a5..192d7a9629466 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3807,9 +3807,8 @@ define i32 @src_and_eq_neg1_or_xor(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], -1
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[OR]], i32 [[XOR]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[XOR]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -3827,9 +3826,8 @@ define i32 @src_and_eq_neg1_xor_or(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], -1
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[OR]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[OR]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -3942,9 +3940,8 @@ define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[AND]], i32 [[XOR]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -3962,9 +3959,8 @@ define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[AND]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
 entry:
@@ -4474,10 +4470,7 @@ define i32 @src_no_trans_select_or_eq0_or_xor(i32 %x, i32 %y) {
 define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_no_trans_select_or_eq0_and_or(
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[OR0]], i32 [[AND]], i32 [[OR]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    ret i32 [[OR]]
 ;
   %or = or i32 %x, %y
   %or0 = icmp eq i32 %or, 0
@@ -4489,10 +4482,7 @@ define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
 define i32 @src_no_trans_select_or_eq0_xor_or(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_no_trans_select_or_eq0_xor_or(
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[OR0]], i32 [[XOR]], i32 [[OR]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    ret i32 [[OR]]
 ;
   %or = or i32 %x, %y
   %or0 = icmp eq i32 %or, 0
diff --git a/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll b/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll
index 86ca1222aa4ea..fb4545cb715bb 100644
--- a/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll
+++ b/llvm/test/Transforms/LoopVectorize/interleaved-accesses-pred-stores.ll
@@ -20,39 +20,37 @@ define void @interleaved_with_cond_store_0(ptr %p, i64 %x, i64 %n) {
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp slt i64 [[N:%.*]], 3
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = and i64 [[N]], 1
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[TMP0]], i64 2, i64 [[N_MOD_VF]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub nsw i64 [[N]], [[TMP1]]
+; CHECK-NEXT:    [[DOTNEG:%.*]] = or i64 [[N]], -2
+; CHECK-NEXT:    [[N_VEC:%.*]] = add nsw i64 [[DOTNEG]], [[N]]
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i64> [[BROADCAST_SPLATINSERT]], <2 x i64> poison, <2 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_STORE_CONTINUE2:%.*]] ]
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 1
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP2]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 1
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> <i32 0, i32 2>
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x i1> [[TMP3]], i64 0
-; CHECK-NEXT:    br i1 [[TMP4]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i1> [[TMP1]], i64 0
+; CHECK-NEXT:    br i1 [[TMP2]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
 ; CHECK:       pred.store.if:
-; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
-; CHECK-NEXT:    store i64 [[TMP6]], ptr [[TMP5]], align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
+; CHECK-NEXT:    store i64 [[TMP4]], ptr [[TMP3]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE]]
 ; CHECK:       pred.store.continue:
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i1> [[TMP3]], i64 1
-; CHECK-NEXT:    br i1 [[TMP7]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[TMP1]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.if1:
-; CHECK-NEXT:    [[TMP8:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP8]], i32 1
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
-; CHECK-NEXT:    store i64 [[TMP10]], ptr [[TMP9]], align 8
+; CHECK-NEXT:    [[TMP6:%.*]] = or disjoint i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP6]], i32 1
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
+; CHECK-NEXT:    store i64 [[TMP8]], ptr [[TMP7]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.continue2:
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    br label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
@@ -61,11 +59,11 @@ define void @interleaved_with_cond_store_0(ptr %p, i64 %x, i64 %n) {
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_NEXT:%.*]], [[IF_MERGE:%.*]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[P_1:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[I]], i32 1
-; CHECK-NEXT:    [[TMP12:%.*]] = load i64, ptr [[P_1]], align 8
-; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[TMP12]], [[X]]
-; CHECK-NEXT:    br i1 [[TMP13]], label [[IF_THEN:%.*]], label [[IF_MERGE]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load i64, ptr [[P_1]], align 8
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp eq i64 [[TMP10]], [[X]]
+; CHECK-NEXT:    br i1 [[TMP11]], label [[IF_THEN:%.*]], label [[IF_MERGE]]
 ; CHECK:       if.then:
-; CHECK-NEXT:    store i64 [[TMP12]], ptr [[P_1]], align 8
+; CHECK-NEXT:    store i64 [[TMP10]], ptr [[P_1]], align 8
 ; CHECK-NEXT:    br label [[IF_MERGE]]
 ; CHECK:       if.merge:
 ; CHECK-NEXT:    [[I_NEXT]] = add nuw nsw i64 [[I]], 1
@@ -114,46 +112,44 @@ define void @interleaved_with_cond_store_1(ptr %p, i64 %x, i64 %n) {
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp slt i64 [[N:%.*]], 3
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = and i64 [[N]], 1
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[TMP0]], i64 2, i64 [[N_MOD_VF]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub nsw i64 [[N]], [[TMP1]]
+; CHECK-NEXT:    [[DOTNEG:%.*]] = or i64 [[N]], -2
+; CHECK-NEXT:    [[N_VEC:%.*]] = add nsw i64 [[DOTNEG]], [[N]]
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i64> [[BROADCAST_SPLATINSERT]], <2 x i64> poison, <2 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_STORE_CONTINUE2:%.*]] ]
-; CHECK-NEXT:    [[TMP2:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 0
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
-; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP2]], i32 1
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP4]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = or disjoint i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [[PAIR:%.*]], ptr [[P:%.*]], i64 [[INDEX]], i32 0
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP0]], i32 1
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP2]], align 8
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> <i32 0, i32 2>
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i1> [[TMP6]], i64 0
-; CHECK-NEXT:    br i1 [[TMP7]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq <2 x i64> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    br i1 [[TMP5]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
 ; CHECK:       pred.store.if:
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 0
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
-; CHECK-NEXT:    store i64 [[TMP9]], ptr [[TMP8]], align 8
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[INDEX]], i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 0
+; CHECK-NEXT:    store i64 [[TMP7]], ptr [[TMP6]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE]]
 ; CHECK:       pred.store.continue:
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x i1> [[TMP6]], i64 1
-; CHECK-NEXT:    br i1 [[TMP10]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x i1> [[TMP4]], i64 1
+; CHECK-NEXT:    br i1 [[TMP8]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.if1:
-; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP2]], i32 0
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
-; CHECK-NEXT:    store i64 [[TMP12]], ptr [[TMP11]], align 8
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[TMP0]], i32 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i64> [[WIDE_VEC]], i64 2
+; CHECK-NEXT:    store i64 [[TMP10]], ptr [[TMP9]], align 8
 ; CHECK-NEXT:    br label [[PRED_STORE_CONTINUE2]]
 ; CHECK:       pred.store.continue2:
-; CHECK-NEXT:    [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP3]], align 8
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 0
-; CHECK-NEXT:    store i64 [[TMP13]], ptr [[TMP4]], align 8
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 2
-; CHECK-NEXT:    store i64 [[TMP14]], ptr [[TMP5]], align 8
+; CHECK-NEXT:    [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 0
+; CHECK-NEXT:    store i64 [[TMP11]], ptr [[TMP2]], align 8
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i64> [[WIDE_VEC3]], i64 2
+; CHECK-NEXT:    store i64 [[TMP12]], ptr [[TMP3]], align 8
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    br label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
@@ -163,15 +159,15 @@ define void @interleaved_with_cond_store_1(ptr %p, i64 %x, i64 %n) {
 ; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_NEXT:%.*]], [[IF_MERGE:%.*]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[P_0:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[I]], i32 0
 ; CHECK-NEXT:    [[P_1:%.*]] = getelementptr inbounds [[PAIR]], ptr [[P]], i64 [[I]], i32 1
-; CHECK-NEXT:    [[TMP16:%.*]] = load i64, ptr [[P_1]], align 8
-; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i64 [[TMP16]], [[X]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[IF_THEN:%.*]], label [[IF_MERGE]]
+; CHECK-NEXT:    [[TMP14:%.*]] = load i64, ptr [[P_1]], align 8
+; CHECK-NEXT:    [[TMP15:%.*]] = ...
[truncated]

CondContext CC(CondVal);
findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
CC.AffectedValues.insert(V);
});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe aggressive inst combine is the right place for this?
Maybe we could also track select instructions in DomConditionCache (although that would req some more complex logic to keep it updated as we created/delect select instructions).

Also, think for select you can at least limit V to values that are actually used by select arms.
Maybe create a set of uses up to Depth = 6 for each of the arms and only add affected values if they also hit that set? Otherwise can't really imagine this helping.

You might also try early out of !isa<Constant>(TrueVal) && !isa<Constant>(FalseVal) to avoid unnecessary setup.

Finally, maybe try simplifying both True and False arm in one go instead of essentially doubling the number of trips.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, think you really should only add V if its TrueArm/FalseArm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the PR description, most of the overhead comes from simply adding a member to SimplifyQuery rather than anything happening in InstCombine. These are the numbers just for adding a dummy member: https://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&to=e7a64d837a061d6afeac9c0f06c4827998d43561&stat=instructions:u

As such, I don't think we'll get any substantial improvement out of changing how exactly the KnownBits calculation is done.

Maybe aggressive inst combine is the right place for this?

This would prevent extending this to use SimplifyDemandedBits, which enabled simplifying the expression without folding to constants (nikic@1b8edbd). This will allow us to subsume special cases like #92658.

Maybe we could also track select instructions in DomConditionCache (although that would req some more complex logic to keep it updated as we created/delect select instructions).

I think managing invalidation for this would be quite tricky, and this approach would be inherently limited to one-use chains to the select only.

Actually, think you really should only add V if its TrueArm/FalseArm.

Why? V may be used by TrueArm/FalseArm recursively.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the PR description, most of the overhead comes from simply adding a member to SimplifyQuery rather than anything happening in InstCombine. These are the numbers just for adding a dummy member: https://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&to=e7a64d837a061d6afeac9c0f06c4827998d43561&stat=instructions:u

As such, I don't think we'll get any substantial improvement out of changing how exactly the KnownBits calculation is done.

I see, although a bit more ugly, what about adding an additional optional argument to computeKnownBits/computeKnownBitsFromContext?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the PR description, most of the overhead comes from simply adding a member to SimplifyQuery rather than anything happening in InstCombine. These are the numbers just for adding a dummy member: https://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&to=e7a64d837a061d6afeac9c0f06c4827998d43561&stat=instructions:u
As such, I don't think we'll get any substantial improvement out of changing how exactly the KnownBits calculation is done.

I see, although a bit more ugly, what about adding an additional optional argument to computeKnownBits/computeKnownBitsFromContext?

Also, esp for ThinLTO, seems to be substantial enough difference to justify saving an iteration of InstCombine IMO. Don't think it will really affect code complexity much.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get what you refer to with "saving an iteration of InstCombine" -- how can we save an InstCombine iteration?

For the record, this is the impact of the KnownBits calculation (as opposed to the extra member and the affected value calculation): http://llvm-compile-time-tracker.com/compare.php?from=c9edd08d49137e47fe3acda3ea5b2bae563cb3ac&to=db730283063dff0dc72338413f346bd0174dafdd&stat=instructions:u

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get what you refer to with "saving an iteration of InstCombine" -- how can we save an InstCombine iteration?

I mean do the full transform in a single shot as opposed to going through the work list twice.

Iteration was wrong word to us there given it's other meaning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean do the full transform in a single shot as opposed to going through the work list twice.

Do you mean the case where it's possible to simplify both select operands? If so, this should be very rare, and as such not worth optimizing.

@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 19, 2024

Top 5 improvements:
meshlab/MarchingCubes.cpp.ll 678109485 659907628 -2.68%
c3c/tokens.c.ll 145223091 141469919 -2.58%
tev/HelpWindow.cpp.ll 1039474883 1012684821 -2.58%
faiss/HNSW.cpp.ll 3522270697 3438578766 -2.38%
rust-analyzer-rs/4ifo5x52byu175vr.ll 302363407 296249464 -2.02%
Top 5 regressions:
eastl/TestBitset.cpp.ll 7448097780 8128445212 +9.13%
hwloc/bitmap.ll 1263490817 1334914518 +5.65%
abc/rsbDec6.c.ll 1330779608 1379597304 +3.67%
gromacs/lincs.cpp.ll 4224451840 4378129884 +3.64%
z3/util.cpp.ll 141417444 146449050 +3.56%
Overall: 0.04440040%

@nikic
Copy link
Contributor Author

nikic commented Jun 19, 2024

@dtcxzyw What metric does this refer to?

@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 19, 2024

@dtcxzyw What metric does this refer to?

It refers to execution time of opt -O3 (in instruction:u). See also dtcxzyw/llvm-opt-benchmark#680.

@nikic
Copy link
Contributor Author

nikic commented Jun 27, 2024

Okay, it seems like the "adding member causes significant compile time regression" is some kind of gcc optimization fluke. Recently 326ba38 landed, which really shouldn't have any impact at all, but it caused this regression: http://llvm-compile-time-tracker.com/compare.php?from=ca4e5a8d6e00d8a851c3bbd01442193f97a80139&to=326ba38a991250a8587a399a260b0f7af2c9166a&stat=instructions:u

And after that change, the impact of this PR is now only barely above the significant threshold: http://llvm-compile-time-tracker.com/compare.php?from=a1ad98813006cefcdf88336db3f81a15b6bf36fb&to=0aac38777827865b40830ed8f00a3c6937dbcec4&stat=instructions:u I think that's good enough to move forward?

I suspect that there may be some kind of optimization that bails out based on function size or something and there are commits that get unlucky with it. There was some time in the past where harmless ValueTracking changes kept toggling back and forth between a slow and a fast variant, but I haven't seem this in quite a while anymore. Looks like it's back now...

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Jun 27, 2024
@goldsteinn
Copy link
Contributor

Okay, it seems like the "adding member causes significant compile time regression" is some kind of gcc optimization fluke.

Any thoughts on switching to LLVM as the default compiler?

@nikic
Copy link
Contributor Author

nikic commented Jun 27, 2024

Top 5 regressions: eastl/TestBitset.cpp.ll 7448097780 8128445212 +9.13%

I looked at the first one of these. This is a case where the computeKnownBits() calls dominate. There are a lot of selects with expensive operands.

This additional patch improves the situation for this kind of pathological case: 254fd43...00597f0 This is what @goldsteinn suggested about first checking whether there are any affected values in the select arms.

It doesn't fully avoid the regression on that test case, as many selects do have potentially affected values, but at least cuts it in half.

Do you want me to include that?

(My initial attempt was a huge regression instead: http://llvm-compile-time-tracker.com/compare.php?from=254fd43272429eac4cdd98924fab8cf5be9019c2&to=54d67e7df109e020398f7027c8df0a715c20a6cc&stat=instructions:u Turns out that having a reduced depth for phi nodes is very, very important!)

@goldsteinn
Copy link
Contributor

It doesn't fully avoid the regression on that test case, as many selects do have potentially affected values, but at least cuts it in half.

Do you want me to include that?

Unless it dramatically undermines the value of the patch, halfway is better than nothing no?

@nikic
Copy link
Contributor Author

nikic commented Jun 27, 2024

It doesn't fully avoid the regression on that test case, as many selects do have potentially affected values, but at least cuts it in half.
Do you want me to include that?

Unless it dramatically undermines the value of the patch, halfway is better than nothing no?

It shouldn't affect the results at all. I went ahead and added it.

Here are the final results for this PR: http://llvm-compile-time-tracker.com/compare.php?from=656b8f5ec4ba3fe8ec7bdef125ccd42ed43b0b16&to=14ad6079cba73d37eb8997f8f3bf576e0fd21f2a&stat=instructions%3Au

}
return any_of(I->operands(), [&](Value *Op) {
return Op->getType()->isIntOrIntVectorTy() &&
hasAffectedValue(Op, Affected, Depth + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SInce you are essentially only going to being able to do computeKnownBitsFromCond up to a depth of two, can you get away with changing your computeKnownBits calls to use MaxAnalysisRecursionDepth - 2? Or does not having the additional bits matter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you say that computeKnownBitsFromCond() will only work up to a depth of two? Or do you mean in the phi case? (The phi case is going to pass MaxAnalysisRecursionDepth - 1 as the new depth.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in, in computeKnownBits, by a depth of two, you will no longer hit any values in "Affected" (or at least rarely, if you hit any they will not be by-chance from multi-use). So it will be similiar to a normal computeKnownBits call which is not what this patch is really after.
I'm not sure if, however, if you need to normal computeKnownBits to provide some extra bits to fill in the constants.


// Ignore the case where the select arm itself is affected. These cases
// are handled more efficiently by other optimizations.
if (Affected.contains(V) && Depth != 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should probably be inverted (cheaper check on left).

nikic added 3 commits June 28, 2024 12:56
Simplify the arms of a select based on the KnownBits implied by
its condition. For now this only handles the case where the select
arm folds to a constant, but this can be generalized to handle
other patterns by using SimplifyDemandedBits instead.
@goldsteinn
Copy link
Contributor

This LGTM if the compile time regressions in @dtcxzyw's suite are mostly fixed.

@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 28, 2024

Top 5 improvements:
meshlab/MarchingCubes.cpp.ll 678109485 659907628 -2.68%
c3c/tokens.c.ll 145223091 141469919 -2.58%
tev/HelpWindow.cpp.ll 1039474883 1012684821 -2.58%
faiss/HNSW.cpp.ll 3522270697 3438578766 -2.38%
rust-analyzer-rs/4ifo5x52byu175vr.ll 302363407 296249464 -2.02%
Top 5 regressions:
eastl/TestBitset.cpp.ll 7448097780 8128445212 +9.13%
hwloc/bitmap.ll 1263490817 1334914518 +5.65%
abc/rsbDec6.c.ll 1330779608 1379597304 +3.67%
gromacs/lincs.cpp.ll 4224451840 4378129884 +3.64%
z3/util.cpp.ll 141417444 146449050 +3.56%
Overall: 0.04440040%

Top 5 improvements:
postgres/unicode_norm_shlib.ll 363748825 350961218 -3.52%
meshlab/MarchingCubes.cpp.ll 675246704 656351184 -2.80%
tev/HelpWindow.cpp.ll 1029062714 1002529143 -2.58%
faiss/AdditiveQuantizer.cpp.ll 1493267399 1458937166 -2.30%
faiss/InvertedLists.cpp.ll 1529590909 1496064094 -2.19%
Top 5 regressions:
eastl/TestBitset.cpp.ll 7404525337 7721492909 +4.28%
hwloc/bitmap.ll 1254672490 1289997586 +2.82%
libwebp/upsampling.c.ll 1495289974 1528316818 +2.21%
faiss/IndexScalarQuantizer.cpp.ll 635573896 649291704 +2.16%
meshlab/meshfilter.cpp.ll 77037371771 78638508357 +2.08%
Overall: 0.04310683%

It looks better now :)

@nikic nikic force-pushed the select-known-bits branch from 05eb203 to 23fe66e Compare June 28, 2024 16:44
Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you!

@nikic nikic merged commit 77eb056 into llvm:main Jul 1, 2024
4 of 6 checks passed
@nikic nikic deleted the select-known-bits branch July 1, 2024 07:26
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
Simplify the arms of a select based on the KnownBits implied by its condition.
For now this only handles the case where the select arm folds to a constant,
but this can be generalized to handle other patterns by using
SimplifyDemandedBits instead (in that case we would also have to limit to
non-undef conditions).

This is implemented by adding a new member to SimplifyQuery that can be used
to inject an additional condition. The affected values are pre-computed and
we don't call computeKnownBits() if the select arms don't contain affected
values. This reduces the cost in some pathological cases.
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
Simplify the arms of a select based on the KnownBits implied by its condition.
For now this only handles the case where the select arm folds to a constant,
but this can be generalized to handle other patterns by using
SimplifyDemandedBits instead (in that case we would also have to limit to
non-undef conditions).

This is implemented by adding a new member to SimplifyQuery that can be used
to inject an additional condition. The affected values are pre-computed and
we don't call computeKnownBits() if the select arms don't contain affected
values. This reduces the cost in some pathological cases.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants