Skip to content

Commit

Permalink
IR: introduce struct with CmpInst::Predicate and samesign (llvm#116867)
Browse files Browse the repository at this point in the history
Introduce llvm::CmpPredicate, an abstraction over a floating-point
predicate, and a pack of an integer predicate with samesign information,
in order to ease extending large portions of the codebase that take a
CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by
migrating parts of ValueTracking, InstructionSimplify, and InstCombine
from CmpInst::Predicate to llvm::CmpPredicate. There should be no
functional changes, as we don't perform any extra optimizations with
samesign in this patch, or use CmpPredicate::getMatching.

The design approach taken by this patch allows for unaudited callers of
APIs that take a llvm::CmpPredicate to silently drop the samesign
information; it does not pose a correctness issue, and allows us to
migrate the codebase piece-wise.
  • Loading branch information
artagnon authored Dec 3, 2024
1 parent f335364 commit 51a895a
Show file tree
Hide file tree
Showing 13 changed files with 228 additions and 111 deletions.
1 change: 1 addition & 0 deletions llvm/include/llvm/Analysis/InstSimplifyFolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/IR/CmpPredicate.h"
#include "llvm/IR/IRBuilderFolder.h"
#include "llvm/IR/Instruction.h"

Expand Down
7 changes: 4 additions & 3 deletions llvm/include/llvm/Analysis/InstructionSimplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class DataLayout;
class DominatorTree;
class Function;
class Instruction;
class CmpPredicate;
class LoadInst;
struct LoopStandardAnalysisResults;
class Pass;
Expand Down Expand Up @@ -152,11 +153,11 @@ Value *simplifyOrInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
Value *simplifyXorInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);

/// Given operands for an ICmpInst, fold the result or return null.
Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS,
const SimplifyQuery &Q);

/// Given operands for an FCmpInst, fold the result or return null.
Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
Value *simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
FastMathFlags FMF, const SimplifyQuery &Q);

/// Given operands for a SelectInst, fold the result or return null.
Expand Down Expand Up @@ -200,7 +201,7 @@ Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,
//=== Helper functions for higher up the class hierarchy.

/// Given operands for a CmpInst, fold the result or return null.
Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
Value *simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
const SimplifyQuery &Q);

/// Given operand for a UnaryOperator, fold the result or return null.
Expand Down
7 changes: 3 additions & 4 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -1255,8 +1255,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS, const Value *RHS,
const DataLayout &DL,
bool LHSIsTrue = true,
unsigned Depth = 0);
std::optional<bool> isImpliedCondition(const Value *LHS,
CmpInst::Predicate RHSPred,
std::optional<bool> isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
const Value *RHSOp0, const Value *RHSOp1,
const DataLayout &DL,
bool LHSIsTrue = true,
Expand All @@ -1267,8 +1266,8 @@ std::optional<bool> isImpliedCondition(const Value *LHS,
std::optional<bool> isImpliedByDomCondition(const Value *Cond,
const Instruction *ContextI,
const DataLayout &DL);
std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
const Value *LHS, const Value *RHS,
std::optional<bool> isImpliedByDomCondition(CmpPredicate Pred, const Value *LHS,
const Value *RHS,
const Instruction *ContextI,
const DataLayout &DL);

Expand Down
62 changes: 62 additions & 0 deletions llvm/include/llvm/IR/CmpPredicate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===- CmpPredicate.h - CmpInst Predicate with samesign information -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A CmpInst::Predicate with any samesign information (applicable to ICmpInst).
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_IR_CMPPREDICATE_H
#define LLVM_IR_CMPPREDICATE_H

#include "llvm/IR/InstrTypes.h"

namespace llvm {
/// An abstraction over a floating-point predicate, and a pack of an integer
/// predicate with samesign information. Some functions in ICmpInst construct
/// and return this type in place of a Predicate.
class CmpPredicate {
CmpInst::Predicate Pred;
bool HasSameSign;

public:
/// Constructed implictly with a either Predicate and samesign information, or
/// just a Predicate, dropping samesign information.
CmpPredicate(CmpInst::Predicate Pred, bool HasSameSign = false)
: Pred(Pred), HasSameSign(HasSameSign) {
assert(!HasSameSign || CmpInst::isIntPredicate(Pred));
}

/// Implictly converts to the underlying Predicate, dropping samesign
/// information.
operator CmpInst::Predicate() const { return Pred; }

/// Query samesign information, for optimizations.
bool hasSameSign() const { return HasSameSign; }

/// Compares two CmpPredicates taking samesign into account and returns the
/// canonicalized CmpPredicate if they match. An alternative to operator==.
///
/// For example,
/// samesign ult + samesign ult -> samesign ult
/// samesign ult + ult -> ult
/// samesign ult + slt -> slt
/// ult + ult -> ult
/// ult + slt -> std::nullopt
static std::optional<CmpPredicate> getMatching(CmpPredicate A,
CmpPredicate B);

/// An operator== on the underlying Predicate.
bool operator==(CmpInst::Predicate P) const { return Pred == P; }

/// There is no operator== defined on CmpPredicate. Use getMatching instead to
/// get the canonicalized matching CmpPredicate.
bool operator==(CmpPredicate) const = delete;
};
} // namespace llvm

#endif
39 changes: 34 additions & 5 deletions llvm/include/llvm/IR/Instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/ADT/iterator.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CmpPredicate.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GEPNoWrapFlags.h"
Expand Down Expand Up @@ -1203,6 +1204,33 @@ class ICmpInst: public CmpInst {
#endif
}

/// @returns the predicate along with samesign information.
CmpPredicate getCmpPredicate() const {
return {getPredicate(), hasSameSign()};
}

/// @returns the inverse predicate along with samesign information: static
/// variant.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred) {
return {getInversePredicate(Pred), Pred.hasSameSign()};
}

/// @returns the inverse predicate along with samesign information.
CmpPredicate getInverseCmpPredicate() const {
return getInverseCmpPredicate(getCmpPredicate());
}

/// @returns the swapped predicate along with samesign information: static
/// variant.
static CmpPredicate getSwappedCmpPredicate(CmpPredicate Pred) {
return {getSwappedPredicate(Pred), Pred.hasSameSign()};
}

/// @returns the swapped predicate.
Predicate getSwappedCmpPredicate() const {
return getSwappedPredicate(getCmpPredicate());
}

/// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
/// @returns the predicate that would be the result if the operand were
/// regarded as signed.
Expand All @@ -1212,7 +1240,7 @@ class ICmpInst: public CmpInst {
}

/// Return the signed version of the predicate: static variant.
static Predicate getSignedPredicate(Predicate pred);
static Predicate getSignedPredicate(Predicate Pred);

/// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
/// @returns the predicate that would be the result if the operand were
Expand All @@ -1223,14 +1251,15 @@ class ICmpInst: public CmpInst {
}

/// Return the unsigned version of the predicate: static variant.
static Predicate getUnsignedPredicate(Predicate pred);
static Predicate getUnsignedPredicate(Predicate Pred);

/// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert
/// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ
/// @returns the unsigned version of the signed predicate pred or
/// the signed version of the signed predicate pred.
static Predicate getFlippedSignednessPredicate(Predicate pred);
/// Static variant.
static Predicate getFlippedSignednessPredicate(Predicate Pred);

/// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert
/// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ
/// @returns the unsigned version of the signed predicate pred or
/// the signed version of the signed predicate pred.
Predicate getFlippedSignednessPredicate() const {
Expand Down
9 changes: 4 additions & 5 deletions llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
/// conditional branch or select to create a compare with a canonical
/// (inverted) predicate which is then more likely to be matched with other
/// values.
static bool isCanonicalPredicate(CmpInst::Predicate Pred) {
static bool isCanonicalPredicate(CmpPredicate Pred) {
switch (Pred) {
case CmpInst::ICMP_NE:
case CmpInst::ICMP_ULE:
Expand Down Expand Up @@ -185,10 +185,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
}

std::optional<std::pair<
CmpInst::Predicate,
Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
Predicate
Pred,
CmpPredicate,
Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
Pred,
Constant *C);

static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
Expand Down
Loading

0 comments on commit 51a895a

Please sign in to comment.