Skip to content

Commit

Permalink
PatternMatch: migrate to CmpPredicate (llvm#118534)
Browse files Browse the repository at this point in the history
With the introduction of CmpPredicate in 51a895a (IR: introduce struct
with CmpInst::Predicate and samesign), PatternMatch is one of the first
key pieces of infrastructure that must be updated to match a CmpInst
respecting samesign information. Implement this change to Cmp-matchers.

This is a preparatory step in migrating the codebase over to
CmpPredicate. Since we no functional changes are desired at this stage,
we have chosen not to migrate CmpPredicate::operator==(CmpPredicate)
calls to use CmpPredicate::getMatching(), as that would have visible
impact on tests that are not yet written: instead, we call
CmpPredicate::operator==(Predicate), preserving the old behavior, while
also inserting a few FIXME comments for follow-ups.

Change-Id: I33f52609ffc5092c200780e93639b880f4ab423a
  • Loading branch information
artagnon authored and searlmc1 committed Dec 16, 2024
1 parent d58046d commit 629442a
Show file tree
Hide file tree
Showing 36 changed files with 232 additions and 192 deletions.
21 changes: 21 additions & 0 deletions llvm/include/llvm/IR/CmpPredicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class CmpPredicate {
bool HasSameSign;

public:
/// Default constructor.
CmpPredicate() : Pred(CmpInst::BAD_ICMP_PREDICATE), HasSameSign(false) {}

/// Constructed implictly with a either Predicate and samesign information, or
/// just a Predicate, dropping samesign information.
CmpPredicate(CmpInst::Predicate Pred, bool HasSameSign = false)
Expand Down Expand Up @@ -52,11 +55,29 @@ class CmpPredicate {

/// An operator== on the underlying Predicate.
bool operator==(CmpInst::Predicate P) const { return Pred == P; }
bool operator!=(CmpInst::Predicate P) const { return Pred != P; }

/// There is no operator== defined on CmpPredicate. Use getMatching instead to
/// get the canonicalized matching CmpPredicate.
bool operator==(CmpPredicate) const = delete;
bool operator!=(CmpPredicate) const = delete;

/// Do a ICmpInst::getCmpPredicate() or CmpInst::getPredicate(), as
/// appropriate.
static CmpPredicate get(const CmpInst *Cmp);

/// Get the swapped predicate of a CmpPredicate.
static CmpPredicate getSwapped(CmpPredicate P);

/// Get the swapped predicate of a CmpInst.
static CmpPredicate getSwapped(const CmpInst *Cmp);

/// Provided to facilitate storing a CmpPredicate in data structures that
/// require hashing.
friend hash_code hash_value(const CmpPredicate &Arg); // NOLINT
};

[[nodiscard]] hash_code hash_value(const CmpPredicate &Arg);
} // namespace llvm

#endif
106 changes: 50 additions & 56 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
}

struct icmp_pred_with_threshold {
ICmpInst::Predicate Pred;
CmpPredicate Pred;
const APInt *Thr;
bool isValue(const APInt &C) { return ICmpInst::compare(C, *Thr, Pred); }
};
Expand Down Expand Up @@ -1557,16 +1557,16 @@ template <typename T> inline Exact_match<T> m_Exact(const T &SubPattern) {
// Matchers for CmpInst classes
//

template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
template <typename LHS_t, typename RHS_t, typename Class,
bool Commutable = false>
struct CmpClass_match {
PredicateTy *Predicate;
CmpPredicate *Predicate;
LHS_t L;
RHS_t R;

// The evaluation order is always stable, regardless of Commutability.
// The LHS is always matched first.
CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS)
CmpClass_match(CmpPredicate &Pred, const LHS_t &LHS, const RHS_t &RHS)
: Predicate(&Pred), L(LHS), R(RHS) {}
CmpClass_match(const LHS_t &LHS, const RHS_t &RHS)
: Predicate(nullptr), L(LHS), R(RHS) {}
Expand All @@ -1575,12 +1575,13 @@ struct CmpClass_match {
if (auto *I = dyn_cast<Class>(V)) {
if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) {
if (Predicate)
*Predicate = I->getPredicate();
*Predicate = CmpPredicate::get(I);
return true;
} else if (Commutable && L.match(I->getOperand(1)) &&
R.match(I->getOperand(0))) {
}
if (Commutable && L.match(I->getOperand(1)) &&
R.match(I->getOperand(0))) {
if (Predicate)
*Predicate = I->getSwappedPredicate();
*Predicate = CmpPredicate::getSwapped(I);
return true;
}
}
Expand All @@ -1589,60 +1590,58 @@ struct CmpClass_match {
};

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
m_Cmp(CmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(Pred, L, R);
inline CmpClass_match<LHS, RHS, CmpInst> m_Cmp(CmpPredicate &Pred, const LHS &L,
const RHS &R) {
return CmpClass_match<LHS, RHS, CmpInst>(Pred, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(Pred, L, R);
inline CmpClass_match<LHS, RHS, ICmpInst> m_ICmp(CmpPredicate &Pred,
const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst>(Pred, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(Pred, L, R);
inline CmpClass_match<LHS, RHS, FCmpInst> m_FCmp(CmpPredicate &Pred,
const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, FCmpInst>(Pred, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
m_Cmp(const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(L, R);
inline CmpClass_match<LHS, RHS, CmpInst> m_Cmp(const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, CmpInst>(L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
m_ICmp(const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(L, R);
inline CmpClass_match<LHS, RHS, ICmpInst> m_ICmp(const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst>(L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
m_FCmp(const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(L, R);
inline CmpClass_match<LHS, RHS, FCmpInst> m_FCmp(const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, FCmpInst>(L, R);
}

// Same as CmpClass, but instead of saving Pred as out output variable, match a
// specific input pred for equality.
template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
template <typename LHS_t, typename RHS_t, typename Class,
bool Commutable = false>
struct SpecificCmpClass_match {
const PredicateTy Predicate;
const CmpPredicate Predicate;
LHS_t L;
RHS_t R;

SpecificCmpClass_match(PredicateTy Pred, const LHS_t &LHS, const RHS_t &RHS)
SpecificCmpClass_match(CmpPredicate Pred, const LHS_t &LHS, const RHS_t &RHS)
: Predicate(Pred), L(LHS), R(RHS) {}

template <typename OpTy> bool match(OpTy *V) {
if (auto *I = dyn_cast<Class>(V)) {
if (I->getPredicate() == Predicate && L.match(I->getOperand(0)) &&
R.match(I->getOperand(1)))
if (CmpPredicate::getMatching(CmpPredicate::get(I), Predicate) &&
L.match(I->getOperand(0)) && R.match(I->getOperand(1)))
return true;
if constexpr (Commutable) {
if (I->getPredicate() == Class::getSwappedPredicate(Predicate) &&
if (CmpPredicate::getMatching(CmpPredicate::get(I),
CmpPredicate::getSwapped(Predicate)) &&
L.match(I->getOperand(1)) && R.match(I->getOperand(0)))
return true;
}
Expand All @@ -1653,31 +1652,27 @@ struct SpecificCmpClass_match {
};

template <typename LHS, typename RHS>
inline SpecificCmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
m_SpecificCmp(CmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(
MatchPred, L, R);
inline SpecificCmpClass_match<LHS, RHS, CmpInst>
m_SpecificCmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, CmpInst>(MatchPred, L, R);
}

template <typename LHS, typename RHS>
inline SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
m_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(
MatchPred, L, R);
inline SpecificCmpClass_match<LHS, RHS, ICmpInst>
m_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, ICmpInst>(MatchPred, L, R);
}

template <typename LHS, typename RHS>
inline SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
m_c_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(
MatchPred, L, R);
inline SpecificCmpClass_match<LHS, RHS, ICmpInst, true>
m_c_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, ICmpInst, true>(MatchPred, L, R);
}

template <typename LHS, typename RHS>
inline SpecificCmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
m_SpecificFCmp(FCmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(
MatchPred, L, R);
inline SpecificCmpClass_match<LHS, RHS, FCmpInst>
m_SpecificFCmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
return SpecificCmpClass_match<LHS, RHS, FCmpInst>(MatchPred, L, R);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2468,7 +2463,7 @@ struct UAddWithOverflow_match {

template <typename OpTy> bool match(OpTy *V) {
Value *ICmpLHS, *ICmpRHS;
ICmpInst::Predicate Pred;
CmpPredicate Pred;
if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
return false;

Expand Down Expand Up @@ -2738,16 +2733,15 @@ inline AnyBinaryOp_match<LHS, RHS, true> m_c_BinOp(const LHS &L, const RHS &R) {
/// Matches an ICmp with a predicate over LHS and RHS in either order.
/// Swaps the predicate if operands are commuted.
template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(Pred, L,
R);
inline CmpClass_match<LHS, RHS, ICmpInst, true>
m_c_ICmp(CmpPredicate &Pred, const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst, true>(Pred, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
m_c_ICmp(const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(L, R);
inline CmpClass_match<LHS, RHS, ICmpInst, true> m_c_ICmp(const LHS &L,
const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst, true>(L, R);
}

/// Matches a specific opcode with LHS and RHS in either order.
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Analysis/IVDescriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
Instruction *I, InstDesc &Prev) {
// We must handle the select(cmp(),x,y) as a single instruction. Advance to
// the select.
CmpInst::Predicate Pred;
CmpPredicate Pred;
if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
return InstDesc(Select, Prev.getRecKind());
Expand Down Expand Up @@ -759,7 +759,7 @@ RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,

// We must handle the select(cmp()) as a single instruction. Advance to the
// select.
CmpInst::Predicate Pred;
CmpPredicate Pred;
if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
return InstDesc(Select, Prev.getRecKind());
Expand Down
16 changes: 8 additions & 8 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1500,12 +1500,12 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
const SimplifyQuery &Q) {
Value *X, *Y;

ICmpInst::Predicate EqPred;
CmpPredicate EqPred;
if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) ||
!ICmpInst::isEquality(EqPred))
return nullptr;

ICmpInst::Predicate UnsignedPred;
CmpPredicate UnsignedPred;

Value *A, *B;
// Y = (A - B);
Expand Down Expand Up @@ -1644,7 +1644,7 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,
static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
const InstrInfoQuery &IIQ) {
// (icmp (add V, C0), C1) & (icmp V, C0)
ICmpInst::Predicate Pred0, Pred1;
CmpPredicate Pred0, Pred1;
const APInt *C0, *C1;
Value *V;
if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
Expand Down Expand Up @@ -1691,7 +1691,7 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
/// Try to simplify and/or of icmp with ctpop intrinsic.
static Value *simplifyAndOrOfICmpsWithCtpop(ICmpInst *Cmp0, ICmpInst *Cmp1,
bool IsAnd) {
ICmpInst::Predicate Pred0, Pred1;
CmpPredicate Pred0, Pred1;
Value *X;
const APInt *C;
if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
Expand Down Expand Up @@ -1735,7 +1735,7 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1,
static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
const InstrInfoQuery &IIQ) {
// (icmp (add V, C0), C1) | (icmp V, C0)
ICmpInst::Predicate Pred0, Pred1;
CmpPredicate Pred0, Pred1;
const APInt *C0, *C1;
Value *V;
if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
Expand Down Expand Up @@ -1891,7 +1891,7 @@ static Value *simplifyAndOrWithICmpEq(unsigned Opcode, Value *Op0, Value *Op1,
unsigned MaxRecurse) {
assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
"Must be and/or");
ICmpInst::Predicate Pred;
CmpPredicate Pred;
Value *A, *B;
if (!match(Op0, m_ICmp(Pred, m_Value(A), m_Value(B))) ||
!ICmpInst::isEquality(Pred))
Expand Down Expand Up @@ -4614,7 +4614,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
Value *FalseVal,
const SimplifyQuery &Q,
unsigned MaxRecurse) {
ICmpInst::Predicate Pred;
CmpPredicate Pred;
Value *CmpLHS, *CmpRHS;
if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
return nullptr;
Expand Down Expand Up @@ -4738,7 +4738,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F,
const SimplifyQuery &Q,
unsigned MaxRecurse) {
FCmpInst::Predicate Pred;
CmpPredicate Pred;
Value *CmpLHS, *CmpRHS;
if (!match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
return nullptr;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/OverflowInstAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace llvm::PatternMatch;

bool llvm::isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd,
Use *&Y) {
ICmpInst::Predicate Pred;
CmpPredicate Pred;
Value *X, *NotOp1;
int XIdx;
IntrinsicInst *II;
Expand Down
Loading

0 comments on commit 629442a

Please sign in to comment.