Skip to content

Commit

Permalink
Fold A == MIN_INT ? B != MIN_INT : A < B to A < B
Browse files Browse the repository at this point in the history
  • Loading branch information
veera-sivarajan committed Dec 18, 2024
1 parent 6dca2fe commit 122e3eb
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 60 deletions.
1 change: 0 additions & 1 deletion llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7141,7 +7141,6 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
NewOps[1], I->getFastMathFlags(), Q, MaxRecurse);
case Instruction::Select:
return simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q, MaxRecurse);
break;
case Instruction::GetElementPtr: {
auto *GEPI = cast<GetElementPtrInst>(I);
return simplifyGEPInst(GEPI->getSourceElementType(), NewOps[0],
Expand Down
52 changes: 52 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,54 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
return nullptr;
}

/// `A == MIN_INT ? B != MIN_INT : A < B` --> `A < B`
/// `A == MAX_INT ? B != MAX_INT : A > B` --> `A > B`
static Value *foldSelectWithExtremeEqCond(Value *CmpLHS, Value *CmpRHS,
Value *TrueVal, Value *FalseVal,
IRBuilderBase &Builder) {
CmpPredicate Pred;
Value *A, *B;

if (!match(FalseVal, m_ICmp(Pred, m_Value(A), m_Value(B))))
return nullptr;

Type *Ty = A->getType();

if (Ty->isPtrOrPtrVectorTy())
return nullptr;

// make sure `CmpLHS` is on the LHS of `FalseVal`.
if (CmpLHS == B) {
std::swap(A, B);
Pred = CmpInst::getSwappedPredicate(Pred);
}

if (CmpLHS != A)
return nullptr;

APInt C;
unsigned BitWidth = Ty->getScalarSizeInBits();

if (ICmpInst::isLT(Pred)) {
C = CmpInst::isSigned(Pred) ? APInt::getSignedMinValue(BitWidth)
: APInt::getMinValue(BitWidth);
} else if (ICmpInst::isGT(Pred)) {
C = CmpInst::isSigned(Pred) ? APInt::getSignedMaxValue(BitWidth)
: APInt::getMaxValue(BitWidth);
} else {
return nullptr;
}

if (!match(CmpRHS, m_SpecificInt(C)))
return nullptr;

if (!match(TrueVal, m_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(B),
m_SpecificInt(C))))
return nullptr;

return Builder.CreateICmp(Pred, A, B);
}

static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
InstCombinerImpl &IC) {
ICmpInst::Predicate Pred = ICI->getPredicate();
Expand All @@ -1795,6 +1843,10 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
if (Pred == ICmpInst::ICMP_NE)
std::swap(TrueVal, FalseVal);

if (Value *V = foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal,
IC.Builder))
return IC.replaceInstUsesWith(SI, V);

// Transform (X == C) ? X : Y -> (X == C) ? C : Y
// specific handling for Bitwise operation.
// x&y -> (x|y) ^ (x^y) or (x|y) & ~(x^y)
Expand Down
74 changes: 15 additions & 59 deletions llvm/test/Transforms/InstCombine/select-with-extreme-eq-cond.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ define i1 @compare_unsigned_min(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_unsigned_min(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i8 [[TMP0]], 0
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], 0
; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP4]], i1 [[TMP3]], i1 [[TMP2]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP2]]
;
start:
%2 = icmp eq i8 %0, 0
Expand All @@ -23,11 +20,8 @@ define i1 @compare_signed_min(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_signed_min(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], -128
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], -128
; CHECK-NEXT: [[TMP4:%.*]] = icmp slt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP4]]
;
start:
%2 = icmp eq i8 %0, -128
Expand All @@ -41,11 +35,8 @@ define i1 @compare_unsigned_max(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_unsigned_max(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], -1
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], -1
; CHECK-NEXT: [[TMP4:%.*]] = icmp ugt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP4]]
;
start:
%2 = icmp eq i8 %0, 255
Expand All @@ -59,11 +50,8 @@ define i1 @compare_signed_max(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_signed_max(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], 127
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], 127
; CHECK-NEXT: [[TMP4:%.*]] = icmp sgt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP4]]
;
start:
%2 = icmp eq i8 %0, 127
Expand All @@ -77,11 +65,8 @@ define i1 @relational_cmp_unsigned_min(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @relational_cmp_unsigned_min(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], 0
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], 0
; CHECK-NEXT: [[TMP4:%.*]] = icmp ult i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP4]]
;
start:
%2 = icmp ule i8 %0, 0
Expand All @@ -95,11 +80,8 @@ define i1 @relational_cmp_signed_min(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @relational_cmp_signed_min(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], -128
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], -128
; CHECK-NEXT: [[TMP4:%.*]] = icmp slt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP4]]
;
start:
%2 = icmp sle i8 %0, -128
Expand All @@ -113,11 +95,8 @@ define i1 @relational_cmp_unsigned_max(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @relational_cmp_unsigned_max(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], -1
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], -1
; CHECK-NEXT: [[TMP4:%.*]] = icmp ugt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP4]]
;
start:
%2 = icmp uge i8 %0, 255
Expand All @@ -131,11 +110,8 @@ define i1 @relational_cmp_signed_max(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @relational_cmp_signed_max(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], 127
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], 127
; CHECK-NEXT: [[TMP4:%.*]] = icmp sgt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP4]]
;
start:
%2 = icmp sge i8 %0, 127
Expand All @@ -151,11 +127,9 @@ define i1 @compare_signed_max_multiuse(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_signed_max_multiuse(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], 127
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], 127
; CHECK-NEXT: [[TMP4:%.*]] = icmp sgt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: call void @use(i1 [[TMP4]])
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: [[RESULT:%.*]] = icmp sgt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RESULT]]
;
start:
Expand All @@ -171,10 +145,7 @@ define i1 @compare_signed_min_samesign(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_signed_min_samesign(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], -128
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], -128
; CHECK-NEXT: [[TMP4:%.*]] = icmp samesign slt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: [[RESULT:%.*]] = icmp slt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RESULT]]
;
start:
Expand All @@ -189,10 +160,7 @@ define i1 @compare_flipped(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_flipped(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], 0
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], 0
; CHECK-NEXT: [[TMP4:%.*]] = icmp ugt i8 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP2]], i1 [[TMP3]], i1 [[TMP4]]
; CHECK-NEXT: [[RESULT:%.*]] = icmp ult i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RESULT]]
;
start:
Expand All @@ -207,11 +175,8 @@ define i1 @compare_swapped(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_swapped(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP0]], 0
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i8 [[TMP1]], 0
; CHECK-NEXT: [[RESULT:%.*]] = icmp ult i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT1:%.*]] = select i1 [[DOTNOT]], i1 [[TMP2]], i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[RESULT1]]
; CHECK-NEXT: ret i1 [[RESULT]]
;
start:
%2 = icmp ne i8 %0, 0
Expand All @@ -225,10 +190,7 @@ define i1 @compare_swapped_flipped_unsigned_max(i8 %0, i8 %1) {
; CHECK-LABEL: define i1 @compare_swapped_flipped_unsigned_max(
; CHECK-SAME: i8 [[TMP0:%.*]], i8 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP0]], -1
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i8 [[TMP1]], -1
; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i8 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[DOTNOT]], i1 [[TMP3]], i1 [[TMP2]]
; CHECK-NEXT: [[RESULT:%.*]] = icmp ugt i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RESULT]]
;
start:
Expand All @@ -243,11 +205,8 @@ define i1 @compare_unsigned_min_illegal_type(i9 %0, i9 %1) {
; CHECK-LABEL: define i1 @compare_unsigned_min_illegal_type(
; CHECK-SAME: i9 [[TMP0:%.*]], i9 [[TMP1:%.*]]) {
; CHECK-NEXT: [[START:.*:]]
; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i9 [[TMP0]], 0
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i9 [[TMP1]], 0
; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i9 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[RESULT:%.*]] = select i1 [[TMP4]], i1 [[TMP3]], i1 [[TMP2]]
; CHECK-NEXT: ret i1 [[RESULT]]
; CHECK-NEXT: ret i1 [[TMP2]]
;
start:
%2 = icmp eq i9 %0, 0
Expand All @@ -260,11 +219,8 @@ start:
define <2 x i1> @compare_vector(<2 x i8> %x, <2 x i8> %y) {
; CHECK-LABEL: define <2 x i1> @compare_vector(
; CHECK-SAME: <2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]]) {
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq <2 x i8> [[X]], zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <2 x i8> [[Y]], zeroinitializer
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult <2 x i8> [[X]], [[Y]]
; CHECK-NEXT: [[RESULT:%.*]] = select <2 x i1> [[TMP3]], <2 x i1> [[TMP2]], <2 x i1> [[TMP1]]
; CHECK-NEXT: ret <2 x i1> [[RESULT]]
; CHECK-NEXT: ret <2 x i1> [[TMP1]]
;
%2 = icmp eq <2 x i8> %x, <i8 0, i8 0>
%3 = icmp ne <2 x i8> %y, <i8 0, i8 0>
Expand Down

0 comments on commit 122e3eb

Please sign in to comment.