Skip to content

Commit

Permalink
[InstCombine] Fold A == MIN_INT ? B != MIN_INT : A < B to A < B (#…
Browse files Browse the repository at this point in the history
…120177)

This PR folds:
 `A == MIN_INT ? B != MIN_INT : A < B` to `A < B`
 `A == MAX_INT ? B != MAX_INT : A > B` to `A > B`

Proof: https://alive2.llvm.org/ce/z/bR6E2s

This helps in optimizing comparison of optional unsigned non-zero types
in rust-lang/rust#49892.

Rust compiler's current output: https://rust.godbolt.org/z/9fxfq3Gn8
  • Loading branch information
veera-sivarajan authored Dec 19, 2024
1 parent 94837c8 commit 6f8afaf
Show file tree
Hide file tree
Showing 3 changed files with 501 additions and 1 deletion.
1 change: 0 additions & 1 deletion llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7144,7 +7144,6 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
NewOps[1], I->getFastMathFlags(), Q, MaxRecurse);
case Instruction::Select:
return simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q, MaxRecurse);
break;
case Instruction::GetElementPtr: {
auto *GEPI = cast<GetElementPtrInst>(I);
return simplifyGEPInst(GEPI->getSourceElementType(), NewOps[0],
Expand Down
44 changes: 44 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,46 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
return nullptr;
}

/// `A == MIN_INT ? B != MIN_INT : A < B` --> `A < B`
/// `A == MAX_INT ? B != MAX_INT : A > B` --> `A > B`
static Instruction *foldSelectWithExtremeEqCond(Value *CmpLHS, Value *CmpRHS,
Value *TrueVal,
Value *FalseVal) {
Type *Ty = CmpLHS->getType();

if (Ty->isPtrOrPtrVectorTy())
return nullptr;

CmpPredicate Pred;
Value *B;

if (!match(FalseVal, m_c_ICmp(Pred, m_Specific(CmpLHS), m_Value(B))))
return nullptr;

Value *TValRHS;
if (!match(TrueVal, m_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(B),
m_Value(TValRHS))))
return nullptr;

APInt C;
unsigned BitWidth = Ty->getScalarSizeInBits();

if (ICmpInst::isLT(Pred)) {
C = CmpInst::isSigned(Pred) ? APInt::getSignedMinValue(BitWidth)
: APInt::getMinValue(BitWidth);
} else if (ICmpInst::isGT(Pred)) {
C = CmpInst::isSigned(Pred) ? APInt::getSignedMaxValue(BitWidth)
: APInt::getMaxValue(BitWidth);
} else {
return nullptr;
}

if (!match(CmpRHS, m_SpecificInt(C)) || !match(TValRHS, m_SpecificInt(C)))
return nullptr;

return new ICmpInst(Pred, CmpLHS, B);
}

static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
InstCombinerImpl &IC) {
ICmpInst::Predicate Pred = ICI->getPredicate();
Expand All @@ -1795,6 +1835,10 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
if (Pred == ICmpInst::ICMP_NE)
std::swap(TrueVal, FalseVal);

if (Instruction *Res =
foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal))
return Res;

// Transform (X == C) ? X : Y -> (X == C) ? C : Y
// specific handling for Bitwise operation.
// x&y -> (x|y) ^ (x^y) or (x|y) & ~(x^y)
Expand Down
Loading

0 comments on commit 6f8afaf

Please sign in to comment.