diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h index 25272e0581c93..0e02d0d5b4865 100644 --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -17,6 +17,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/SimplifyQuery.h" +#include "llvm/Analysis/WithCache.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/FMF.h" @@ -90,6 +91,12 @@ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts, const DominatorTree *DT = nullptr, bool UseInstrInfo = true); +KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts, + unsigned Depth, const SimplifyQuery &Q); + +KnownBits computeKnownBits(const Value *V, unsigned Depth, + const SimplifyQuery &Q); + /// Compute known bits from the range metadata. /// \p KnownZero the set of bits that are known to be zero /// \p KnownOne the set of bits that are known to be one @@ -107,7 +114,8 @@ KnownBits analyzeKnownBitsFromAndXorOr( bool UseInstrInfo = true); /// Return true if LHS and RHS have no common bits set. -bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS, +bool haveNoCommonBitsSet(const WithCache &LHSCache, + const WithCache &RHSCache, const SimplifyQuery &SQ); /// Return true if the given value is known to have exactly one bit set when @@ -847,9 +855,12 @@ OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS, const SimplifyQuery &SQ); OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS, const SimplifyQuery &SQ); -OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS, - const SimplifyQuery &SQ); -OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS, +OverflowResult +computeOverflowForUnsignedAdd(const WithCache &LHS, + const WithCache &RHS, + const SimplifyQuery &SQ); +OverflowResult computeOverflowForSignedAdd(const WithCache &LHS, + const WithCache &RHS, const SimplifyQuery &SQ); /// This version also leverages the sign bit of Add if known. OverflowResult computeOverflowForSignedAdd(const AddOperator *Add, diff --git a/llvm/include/llvm/Analysis/WithCache.h b/llvm/include/llvm/Analysis/WithCache.h new file mode 100644 index 0000000000000..8065c45738f84 --- /dev/null +++ b/llvm/include/llvm/Analysis/WithCache.h @@ -0,0 +1,71 @@ +//===- llvm/Analysis/WithCache.h - KnownBits cache for pointers -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Store a pointer to any type along with the KnownBits information for it +// that is computed lazily (if required). +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_WITHCACHE_H +#define LLVM_ANALYSIS_WITHCACHE_H + +#include "llvm/IR/Value.h" +#include "llvm/Support/KnownBits.h" +#include + +namespace llvm { +struct SimplifyQuery; +KnownBits computeKnownBits(const Value *V, unsigned Depth, + const SimplifyQuery &Q); + +template class WithCache { + static_assert(std::is_pointer_v, "WithCache requires a pointer type!"); + + using UnderlyingType = std::remove_pointer_t; + constexpr static bool IsConst = std::is_const_v; + + template + using conditionally_const_t = std::conditional_t; + + using PointerType = conditionally_const_t; + using ReferenceType = conditionally_const_t; + + // Store the presence of the KnownBits information in one of the bits of + // Pointer. + // true -> present + // false -> absent + mutable PointerIntPair Pointer; + mutable KnownBits Known; + + void calculateKnownBits(const SimplifyQuery &Q) const { + Known = computeKnownBits(Pointer.getPointer(), 0, Q); + Pointer.setInt(true); + } + +public: + WithCache(PointerType Pointer) : Pointer(Pointer, false) {} + WithCache(PointerType Pointer, const KnownBits &Known) + : Pointer(Pointer, true), Known(Known) {} + + [[nodiscard]] PointerType getValue() const { return Pointer.getPointer(); } + + [[nodiscard]] const KnownBits &getKnownBits(const SimplifyQuery &Q) const { + if (!hasKnownBits()) + calculateKnownBits(Q); + return Known; + } + + [[nodiscard]] bool hasKnownBits() const { return Pointer.getInt(); } + + operator PointerType() const { return Pointer.getPointer(); } + PointerType operator->() const { return Pointer.getPointer(); } + ReferenceType operator*() const { return *Pointer.getPointer(); } +}; +} // namespace llvm + +#endif diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h index dcfcc8f41dd58..f8b3874267ded 100644 --- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h +++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h @@ -510,15 +510,18 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner { SQ.getWithInstruction(CxtI)); } - OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { + OverflowResult + computeOverflowForUnsignedAdd(const WithCache &LHS, + const WithCache &RHS, + const Instruction *CxtI) const { return llvm::computeOverflowForUnsignedAdd(LHS, RHS, SQ.getWithInstruction(CxtI)); } - OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS, - const Instruction *CxtI) const { + OverflowResult + computeOverflowForSignedAdd(const WithCache &LHS, + const WithCache &RHS, + const Instruction *CxtI) const { return llvm::computeOverflowForSignedAdd(LHS, RHS, SQ.getWithInstruction(CxtI)); } diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 82310444326d6..1e0281b3f1bd7 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/Analysis/WithCache.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -178,17 +179,11 @@ void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo)); } -static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts, - unsigned Depth, const SimplifyQuery &Q); - -static KnownBits computeKnownBits(const Value *V, unsigned Depth, - const SimplifyQuery &Q); - KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { - return ::computeKnownBits( + return computeKnownBits( V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo)); } @@ -196,13 +191,17 @@ KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { - return ::computeKnownBits( + return computeKnownBits( V, DemandedElts, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo)); } -bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, +bool llvm::haveNoCommonBitsSet(const WithCache &LHSCache, + const WithCache &RHSCache, const SimplifyQuery &SQ) { + const Value *LHS = LHSCache.getValue(); + const Value *RHS = RHSCache.getValue(); + assert(LHS->getType() == RHS->getType() && "LHS and RHS should have the same type"); assert(LHS->getType()->isIntOrIntVectorTy() && @@ -250,12 +249,9 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) return true; } - IntegerType *IT = cast(LHS->getType()->getScalarType()); - KnownBits LHSKnown(IT->getBitWidth()); - KnownBits RHSKnown(IT->getBitWidth()); - ::computeKnownBits(LHS, LHSKnown, 0, SQ); - ::computeKnownBits(RHS, RHSKnown, 0, SQ); - return KnownBits::haveNoCommonBitsSet(LHSKnown, RHSKnown); + + return KnownBits::haveNoCommonBitsSet(LHSCache.getKnownBits(SQ), + RHSCache.getKnownBits(SQ)); } bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) { @@ -1784,19 +1780,19 @@ static void computeKnownBitsFromOperator(const Operator *I, /// Determine which bits of V are known to be either zero or one and return /// them. -KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts, - unsigned Depth, const SimplifyQuery &Q) { +KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts, + unsigned Depth, const SimplifyQuery &Q) { KnownBits Known(getBitWidth(V->getType(), Q.DL)); - computeKnownBits(V, DemandedElts, Known, Depth, Q); + ::computeKnownBits(V, DemandedElts, Known, Depth, Q); return Known; } /// Determine which bits of V are known to be either zero or one and return /// them. -KnownBits computeKnownBits(const Value *V, unsigned Depth, - const SimplifyQuery &Q) { +KnownBits llvm::computeKnownBits(const Value *V, unsigned Depth, + const SimplifyQuery &Q) { KnownBits Known(getBitWidth(V->getType(), Q.DL)); - computeKnownBits(V, Known, Depth, Q); + ::computeKnownBits(V, Known, Depth, Q); return Known; } @@ -6256,10 +6252,11 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) { /// Combine constant ranges from computeConstantRange() and computeKnownBits(). static ConstantRange -computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned, +computeConstantRangeIncludingKnownBits(const WithCache &V, + bool ForSigned, const SimplifyQuery &SQ) { - KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ); - ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned); + ConstantRange CR1 = + ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned); ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo); ConstantRange::PreferredRangeType RangeType = ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned; @@ -6269,8 +6266,8 @@ computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned, OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS, const SimplifyQuery &SQ) { - KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ); - KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ); + KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ); + KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ); ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false); ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false); return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange)); @@ -6307,17 +6304,18 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS, // product is exactly the minimum negative number. // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 // For simplicity we just check if at least one side is not negative. - KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ); - KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ); + KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ); + KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ); if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) return OverflowResult::NeverOverflows; } return OverflowResult::MayOverflow; } -OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, - const Value *RHS, - const SimplifyQuery &SQ) { +OverflowResult +llvm::computeOverflowForUnsignedAdd(const WithCache &LHS, + const WithCache &RHS, + const SimplifyQuery &SQ) { ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ); ConstantRange RHSRange = @@ -6325,10 +6323,10 @@ OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange)); } -static OverflowResult computeOverflowForSignedAdd(const Value *LHS, - const Value *RHS, - const AddOperator *Add, - const SimplifyQuery &SQ) { +static OverflowResult +computeOverflowForSignedAdd(const WithCache &LHS, + const WithCache &RHS, + const AddOperator *Add, const SimplifyQuery &SQ) { if (Add && Add->hasNoSignedWrap()) { return OverflowResult::NeverOverflows; } @@ -6944,9 +6942,10 @@ OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add, Add, SQ); } -OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS, - const Value *RHS, - const SimplifyQuery &SQ) { +OverflowResult +llvm::computeOverflowForSignedAdd(const WithCache &LHS, + const WithCache &RHS, + const SimplifyQuery &SQ) { return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 44f6e37cb3b44..87181650e7587 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1566,7 +1566,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); // A+B --> A|B iff A and B have no bits set in common. - if (haveNoCommonBitsSet(LHS, RHS, SQ.getWithInstruction(&I))) + WithCache LHSCache(LHS), RHSCache(RHS); + if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I))) return BinaryOperator::CreateOr(LHS, RHS); if (Instruction *Ext = narrowMathIfNoOverflow(I)) @@ -1661,11 +1662,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. bool Changed = false; - if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) { + if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) { + if (!I.hasNoUnsignedWrap() && + willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) { Changed = true; I.setHasNoUnsignedWrap(true); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 83c127a0ef012..a53d67b2899b7 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -295,13 +295,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext); - bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, + bool willNotOverflowSignedAdd(const WithCache &LHS, + const WithCache &RHS, const Instruction &CxtI) const { return computeOverflowForSignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; } - bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS, + bool willNotOverflowUnsignedAdd(const WithCache &LHS, + const WithCache &RHS, const Instruction &CxtI) const { return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows;