Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Simplify select using KnownBits of condition #95923

Merged
merged 4 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions llvm/include/llvm/Analysis/SimplifyQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef LLVM_ANALYSIS_SIMPLIFYQUERY_H
#define LLVM_ANALYSIS_SIMPLIFYQUERY_H

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Operator.h"

namespace llvm {
Expand Down Expand Up @@ -57,13 +58,23 @@ struct InstrInfoQuery {
}
};

/// Evaluate query assuming this condition holds.
///
/// A SimplifyQuery can point at one of these through its CC member; the
/// known-bits context computation then applies Cond to the values listed in
/// AffectedValues.
struct CondContext {
/// The condition value the query should treat as an assumption.
Value *Cond;
/// Forwarded to computeKnownBitsFromCond together with Cond. The select
/// fold sets this to true before analysing the false arm, i.e. when the
/// negation of Cond is the assumption.
bool Invert = false;
/// Values for which Cond is consulted: the injected condition is only
/// applied to a value that is contained in this set. It starts out empty
/// and is filled in by the user of the context.
SmallPtrSet<Value *, 4> AffectedValues;

/// Creates a context for \p Cond with Invert left at false and no affected
/// values recorded yet.
CondContext(Value *Cond) : Cond(Cond) {}
};

struct SimplifyQuery {
const DataLayout &DL;
const TargetLibraryInfo *TLI = nullptr;
const DominatorTree *DT = nullptr;
AssumptionCache *AC = nullptr;
const Instruction *CxtI = nullptr;
const DomConditionCache *DC = nullptr;
const CondContext *CC = nullptr;

// Wrapper to query additional information for instructions like metadata or
// keywords like nsw, which provides conservative results if those cannot
Expand Down Expand Up @@ -113,6 +124,12 @@ struct SimplifyQuery {
Copy.DC = nullptr;
return Copy;
}

/// Returns a copy of this query whose CC member points at \p CC.
///
/// Only the address of \p CC is stored, not a copy of it. \p CC must
/// therefore outlive the returned query, and any later change made to
/// \p CC (such as flipping Invert or inserting into AffectedValues) is
/// observed by the returned query.
SimplifyQuery getWithCondContext(const CondContext &CC) const {
SimplifyQuery Copy(*this);
// Keep a pointer to the caller's context rather than copying it.
Copy.CC = &CC;
return Copy;
}
};

} // end namespace llvm
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,10 @@ static void computeKnownBitsFromCond(const Value *V, Value *Cond,

void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
unsigned Depth, const SimplifyQuery &Q) {
// Handle injected condition.
if (Q.CC && Q.CC->AffectedValues.contains(V))
computeKnownBitsFromCond(V, Q.CC->Cond, Known, Depth, Q, Q.CC->Invert);

if (!Q.CxtI)
return;

Expand Down
55 changes: 55 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3519,6 +3519,33 @@ static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0,
return false;
}

/// Check whether the KnownBits of a select arm may be affected by the
/// select condition.
static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
unsigned Depth) {
if (Depth == MaxAnalysisRecursionDepth)
return false;

// Ignore the case where the select arm itself is affected. These cases
// are handled more efficiently by other optimizations.
if (Depth != 0 && Affected.contains(V))
return true;

if (auto *I = dyn_cast<Instruction>(V)) {
if (isa<PHINode>(I)) {
if (Depth == MaxAnalysisRecursionDepth - 1)
return false;
Depth = MaxAnalysisRecursionDepth - 2;
}
return any_of(I->operands(), [&](Value *Op) {
return Op->getType()->isIntOrIntVectorTy() &&
hasAffectedValue(Op, Affected, Depth + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are essentially only going to be able to do computeKnownBitsFromCond up to a depth of two, can you get away with changing your computeKnownBits calls to use MaxAnalysisRecursionDepth - 2? Or does not having the additional bits matter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you say that computeKnownBitsFromCond() will only work up to a depth of two? Or do you mean in the phi case? (The phi case is going to pass MaxAnalysisRecursionDepth - 1 as the new depth.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in, in computeKnownBits, by a depth of two, you will no longer hit any values in "Affected" (or at least rarely, if you hit any they will not be by-chance from multi-use). So it will be similar to a normal computeKnownBits call, which is not what this patch is really after.
I'm not sure, however, whether you need the normal computeKnownBits to provide some extra bits to fill in the constants.

});
}

return false;
}

Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Expand Down Expand Up @@ -4016,5 +4043,33 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (CondVal->getType() == SI.getType() && isKnownInversion(FalseVal, TrueVal))
return BinaryOperator::CreateXor(CondVal, FalseVal);

if (SelType->isIntOrIntVectorTy() &&
(!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) {
// Try to simplify select arms based on KnownBits implied by the condition.
CondContext CC(CondVal);
findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
CC.AffectedValues.insert(V);
});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe aggressive inst combine is the right place for this?
Maybe we could also track select instructions in DomConditionCache (although that would require some more complex logic to keep it updated as we create/delete select instructions).

Also, think for select you can at least limit V to values that are actually used by select arms.
Maybe create a set of uses up to Depth = 6 for each of the arms and only add affected values if they also hit that set? Otherwise can't really imagine this helping.

You might also try early out of !isa<Constant>(TrueVal) && !isa<Constant>(FalseVal) to avoid unnecessary setup.

Finally, maybe try simplifying both True and False arm in one go instead of essentially doubling the number of trips.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think you really should only add V if it's TrueArm/FalseArm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the PR description, most of the overhead comes from simply adding a member to SimplifyQuery rather than anything happening in InstCombine. These are the numbers just for adding a dummy member: https://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&to=e7a64d837a061d6afeac9c0f06c4827998d43561&stat=instructions:u

As such, I don't think we'll get any substantial improvement out of changing how exactly the KnownBits calculation is done.

Maybe aggressive inst combine is the right place for this?

This would prevent extending this to use SimplifyDemandedBits, which enabled simplifying the expression without folding to constants (nikic@1b8edbd). This will allow us to subsume special cases like #92658.

Maybe we could also track select instructions in DomConditionCache (although that would require some more complex logic to keep it updated as we create/delete select instructions).

I think managing invalidation for this would be quite tricky, and this approach would be inherently limited to one-use chains to the select only.

Actually, I think you really should only add V if it's TrueArm/FalseArm.

Why? V may be used by TrueArm/FalseArm recursively.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the PR description, most of the overhead comes from simply adding a member to SimplifyQuery rather than anything happening in InstCombine. These are the numbers just for adding a dummy member: https://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&to=e7a64d837a061d6afeac9c0f06c4827998d43561&stat=instructions:u

As such, I don't think we'll get any substantial improvement out of changing how exactly the KnownBits calculation is done.

I see, although a bit more ugly, what about adding an additional optional argument to computeKnownBits/computeKnownBitsFromContext?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the PR description, most of the overhead comes from simply adding a member to SimplifyQuery rather than anything happening in InstCombine. These are the numbers just for adding a dummy member: https://llvm-compile-time-tracker.com/compare.php?from=3ca17443ef4af21bdb1f3b4fbcfff672cbc6176c&to=e7a64d837a061d6afeac9c0f06c4827998d43561&stat=instructions:u
As such, I don't think we'll get any substantial improvement out of changing how exactly the KnownBits calculation is done.

I see, although a bit more ugly, what about adding an additional optional argument to computeKnownBits/computeKnownBitsFromContext?

Also, esp for ThinLTO, seems to be substantial enough difference to justify saving an iteration of InstCombine IMO. Don't think it will really affect code complexity much.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get what you refer to with "saving an iteration of InstCombine" -- how can we save an InstCombine iteration?

For the record, this is the impact of the KnownBits calculation (as opposed to the extra member and the affected value calculation): http://llvm-compile-time-tracker.com/compare.php?from=c9edd08d49137e47fe3acda3ea5b2bae563cb3ac&to=db730283063dff0dc72338413f346bd0174dafdd&stat=instructions:u

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get what you refer to with "saving an iteration of InstCombine" -- how can we save an InstCombine iteration?

I mean do the full transform in a single shot as opposed to going through the work list twice.

"Iteration" was the wrong word to use there, given its other meaning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean do the full transform in a single shot as opposed to going through the work list twice.

Do you mean the case where it's possible to simplify both select operands? If so, this should be very rare, and as such not worth optimizing.

SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
if (!CC.AffectedValues.empty()) {
if (!isa<Constant>(TrueVal) &&
hasAffectedValue(TrueVal, CC.AffectedValues, /*Depth=*/0)) {
KnownBits Known = llvm::computeKnownBits(TrueVal, /*Depth=*/0, Q);
if (Known.isConstant())
return replaceOperand(SI, 1,
ConstantInt::get(SelType, Known.getConstant()));
}

CC.Invert = true;
if (!isa<Constant>(FalseVal) &&
hasAffectedValue(FalseVal, CC.AffectedValues, /*Depth=*/0)) {
KnownBits Known = llvm::computeKnownBits(FalseVal, /*Depth=*/0, Q);
if (Known.isConstant())
return replaceOperand(SI, 2,
ConstantInt::get(SelType, Known.getConstant()));
}
}
}

return nullptr;
}
5 changes: 1 addition & 4 deletions llvm/test/Transforms/InstCombine/select-binop-cmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,7 @@ define <2 x i8> @select_xor_icmp_vec_bad(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z)

define <2 x i32> @vec_select_no_equivalence(<2 x i32> %x) {
; CHECK-LABEL: @vec_select_no_equivalence(
; CHECK-NEXT: [[X10:%.*]] = shufflevector <2 x i32> [[X:%.*]], <2 x i32> poison, <2 x i32> <i32 1, i32 0>
; CHECK-NEXT: [[COND:%.*]] = icmp eq <2 x i32> [[X]], zeroinitializer
; CHECK-NEXT: [[S:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[X10]], <2 x i32> [[X]]
; CHECK-NEXT: ret <2 x i32> [[S]]
; CHECK-NEXT: ret <2 x i32> [[X:%.*]]
;
%x10 = shufflevector <2 x i32> %x, <2 x i32> undef, <2 x i32> <i32 1, i32 0>
%cond = icmp eq <2 x i32> %x, zeroinitializer
Expand Down
8 changes: 3 additions & 5 deletions llvm/test/Transforms/InstCombine/select-of-bittest.ll
Original file line number Diff line number Diff line change
Expand Up @@ -588,11 +588,9 @@ define i32 @n4(i32 %arg) {

define i32 @n5(i32 %arg) {
; CHECK-LABEL: @n5(
; CHECK-NEXT: [[T:%.*]] = and i32 [[ARG:%.*]], 2
; CHECK-NEXT: [[T1:%.*]] = icmp eq i32 [[T]], 0
; CHECK-NEXT: [[T2:%.*]] = and i32 [[ARG]], 2
; CHECK-NEXT: [[T3:%.*]] = select i1 [[T1]], i32 [[T2]], i32 1
; CHECK-NEXT: ret i32 [[T3]]
; CHECK-NEXT: [[T:%.*]] = lshr i32 [[ARG:%.*]], 1
; CHECK-NEXT: [[T_LOBIT:%.*]] = and i32 [[T]], 1
; CHECK-NEXT: ret i32 [[T_LOBIT]]
;
%t = and i32 %arg, 2
%t1 = icmp eq i32 %t, 0
Expand Down
22 changes: 6 additions & 16 deletions llvm/test/Transforms/InstCombine/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3807,9 +3807,8 @@ define i32 @src_and_eq_neg1_or_xor(i32 %x, i32 %y) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], -1
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y]], [[X]]
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[OR]], i32 [[XOR]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[XOR]]
; CHECK-NEXT: ret i32 [[COND]]
;
entry:
Expand All @@ -3827,9 +3826,8 @@ define i32 @src_and_eq_neg1_xor_or(i32 %x, i32 %y) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], -1
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y]], [[X]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[OR]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[OR]]
; CHECK-NEXT: ret i32 [[COND]]
;
entry:
Expand Down Expand Up @@ -3942,9 +3940,8 @@ define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y]], [[X]]
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[AND]], i32 [[XOR]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
; CHECK-NEXT: ret i32 [[COND]]
;
entry:
Expand All @@ -3962,9 +3959,8 @@ define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y]], [[X]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[AND]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
; CHECK-NEXT: ret i32 [[COND]]
;
entry:
Expand Down Expand Up @@ -4474,10 +4470,7 @@ define i32 @src_no_trans_select_or_eq0_or_xor(i32 %x, i32 %y) {
define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
; CHECK-LABEL: @src_no_trans_select_or_eq0_and_or(
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], [[Y]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 [[AND]], i32 [[OR]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: ret i32 [[OR]]
;
%or = or i32 %x, %y
%or0 = icmp eq i32 %or, 0
Expand All @@ -4489,10 +4482,7 @@ define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
define i32 @src_no_trans_select_or_eq0_xor_or(i32 %x, i32 %y) {
; CHECK-LABEL: @src_no_trans_select_or_eq0_xor_or(
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X]], [[Y]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 [[XOR]], i32 [[OR]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: ret i32 [[OR]]
;
%or = or i32 %x, %y
%or0 = icmp eq i32 %or, 0
Expand Down
Loading
Loading