|
| 1 | +//===-- NullCheckElimination.cpp - Null Check Elimination Pass ------------===// |
| 2 | +// |
| 3 | +// The LLVM Compiler Infrastructure |
| 4 | +// |
| 5 | +// This file is distributed under the University of Illinois Open Source |
| 6 | +// License. See LICENSE.TXT for details. |
| 7 | +// |
| 8 | +//===----------------------------------------------------------------------===// |
| 9 | + |
| 10 | +#include "llvm/Transforms/Scalar.h" |
| 11 | +#include "llvm/ADT/DenseSet.h" |
| 12 | +#include "llvm/ADT/SmallPtrSet.h" |
| 13 | +#include "llvm/ADT/Statistic.h" |
| 14 | +#include "llvm/IR/Constants.h" |
| 15 | +#include "llvm/IR/Function.h" |
| 16 | +#include "llvm/IR/Instructions.h" |
| 17 | +#include "llvm/Pass.h" |
| 18 | +using namespace llvm; |
| 19 | + |
| 20 | +#define DEBUG_TYPE "null-check-elimination" |
| 21 | + |
| 22 | +namespace { |
| 23 | + struct NullCheckElimination : public FunctionPass { |
| 24 | + static char ID; |
| 25 | + NullCheckElimination() : FunctionPass(ID) { |
| 26 | + initializeNullCheckEliminationPass(*PassRegistry::getPassRegistry()); |
| 27 | + } |
| 28 | + bool runOnFunction(Function &F) override; |
| 29 | + |
| 30 | + void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 31 | + AU.setPreservesCFG(); |
| 32 | + } |
| 33 | + |
| 34 | + private: |
| 35 | + static const unsigned kPhiLimit = 16; |
| 36 | + typedef SmallPtrSet<PHINode*, kPhiLimit> SmallPhiSet; |
| 37 | + enum NullCheckResult { |
| 38 | + NotNullCheck, |
| 39 | + NullCheckEq, |
| 40 | + NullCheckNe, |
| 41 | + }; |
| 42 | + |
| 43 | + bool isNonNullOrPoisonPhi(SmallPhiSet *VisitedPhis, PHINode*); |
| 44 | + |
| 45 | + NullCheckResult isCmpNullCheck(ICmpInst*); |
| 46 | + std::pair<Use*, NullCheckResult> findNullCheck(Use*); |
| 47 | + |
| 48 | + bool blockContainsLoadDerivedFrom(BasicBlock*, Value*); |
| 49 | + |
| 50 | + DenseSet<Value*> NonNullOrPoisonValues; |
| 51 | + }; |
| 52 | +} |
| 53 | + |
| 54 | +char NullCheckElimination::ID = 0; |
| 55 | +INITIALIZE_PASS_BEGIN(NullCheckElimination, |
| 56 | + "null-check-elimination", |
| 57 | + "Null Check Elimination", |
| 58 | + false, false) |
| 59 | +INITIALIZE_PASS_END(NullCheckElimination, |
| 60 | + "null-check-elimination", |
| 61 | + "Null Check Elimination", |
| 62 | + false, false) |
| 63 | + |
| 64 | +FunctionPass *llvm::createNullCheckEliminationPass() { |
| 65 | + return new NullCheckElimination(); |
| 66 | +} |
| 67 | + |
| 68 | +bool NullCheckElimination::runOnFunction(Function &F) { |
| 69 | + if (skipOptnoneFunction(F)) |
| 70 | + return false; |
| 71 | + |
| 72 | + bool Changed = false; |
| 73 | + |
| 74 | + // Collect argumetns with the `nonnull` attribute. |
| 75 | + for (auto &Arg : F.args()) { |
| 76 | + if (Arg.hasNonNullAttr()) |
| 77 | + NonNullOrPoisonValues.insert(&Arg); |
| 78 | + } |
| 79 | + |
| 80 | + // Collect instructions that definitely produce nonnull-or-poison values. |
| 81 | + // At the moment, this is restricted to inbounds GEPs. It would be slightly |
| 82 | + // more difficult to include uses of values dominated by a null check, since |
| 83 | + // then we would have to consider uses instead of mere values. |
| 84 | + for (auto &BB : F) { |
| 85 | + for (auto &I : BB) { |
| 86 | + if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { |
| 87 | + if (GEP->isInBounds()) { |
| 88 | + NonNullOrPoisonValues.insert(GEP); |
| 89 | + } |
| 90 | + } |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + // Find phis that are derived entirely from nonnull-or-poison values, |
| 95 | + // including other phis that are themselves derived entirely from these |
| 96 | + // values. |
| 97 | + for (auto &BB : F) { |
| 98 | + for (auto &I : BB) { |
| 99 | + auto *PN = dyn_cast<PHINode>(&I); |
| 100 | + if (!PN) |
| 101 | + break; |
| 102 | + |
| 103 | + SmallPhiSet VisitedPHIs; |
| 104 | + if (isNonNullOrPoisonPhi(&VisitedPHIs, PN)) |
| 105 | + NonNullOrPoisonValues.insert(PN); |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + for (auto &BB : F) { |
| 110 | + // This could also be extended to handle SwitchInst, but using a SwitchInst |
| 111 | + // for a null check seems unlikely. |
| 112 | + auto *BI = dyn_cast<BranchInst>(BB.getTerminator()); |
| 113 | + if (!BI || BI->isUnconditional()) |
| 114 | + continue; |
| 115 | + |
| 116 | + // The first operand of a conditional branch is the condition. |
| 117 | + auto result = findNullCheck(&BI->getOperandUse(0)); |
| 118 | + if (!result.first) |
| 119 | + continue; |
| 120 | + assert((result.second == NullCheckEq || result.second == NullCheckNe) && |
| 121 | + "valid null check kind expected if ICmpInst was found"); |
| 122 | + |
| 123 | + BasicBlock *NonNullBB; |
| 124 | + if (result.second == NullCheckEq) { |
| 125 | + // If the comparison instruction is checking for equaliity with null, |
| 126 | + // then the pointer is nonnull on the `false` branch. |
| 127 | + NonNullBB = BI->getSuccessor(1); |
| 128 | + } else { |
| 129 | + // Otherwise, if the comparison instruction is checking for inequality |
| 130 | + // with null, the pointer is nonnull on the `true` branch. |
| 131 | + NonNullBB = BI->getSuccessor(0); |
| 132 | + } |
| 133 | + |
| 134 | + Use *U = result.first; |
| 135 | + ICmpInst *CI = cast<ICmpInst>(U->get()); |
| 136 | + unsigned nonConstantIndex; |
| 137 | + if (isa<Constant>(CI->getOperand(0))) |
| 138 | + nonConstantIndex = 1; |
| 139 | + else |
| 140 | + nonConstantIndex = 0; |
| 141 | + |
| 142 | + // Due to the semantics of poison values in LLVM, we have to check that |
| 143 | + // there is actually some externally visible side effect that is dependent |
| 144 | + // on the poison value. Since poison values are otherwise treated as undef, |
| 145 | + // and a load of undef is undefined behavior (which is externally visible), |
| 146 | + // it suffices to look for a load of the nonnull-or-poison value. |
| 147 | + // |
| 148 | + // This could be extended to any block control-dependent on this branch of |
| 149 | + // the null check, it's unclear if that will actually catch more cases in |
| 150 | + // real code. |
| 151 | + Value *PtrV = CI->getOperand(nonConstantIndex); |
| 152 | + if (blockContainsLoadDerivedFrom(NonNullBB, PtrV)) { |
| 153 | + Type *BoolTy = CI->getType(); |
| 154 | + Value *NewV = ConstantInt::get(BoolTy, result.second == NullCheckNe); |
| 155 | + U->set(NewV); |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + NonNullOrPoisonValues.clear(); |
| 160 | + |
| 161 | + return Changed; |
| 162 | +} |
| 163 | + |
| 164 | +/// Checks whether a phi is derived from known nonnnull-or-poison values, |
| 165 | +/// including other phis that are derived from the same. May return `false` |
| 166 | +/// conservatively in some cases, e.g. if exploring a large cycle of phis. |
| 167 | +bool |
| 168 | +NullCheckElimination::isNonNullOrPoisonPhi(SmallPhiSet *VisitedPhis, |
| 169 | + PHINode *PN) { |
| 170 | + // If we've already seen this phi, return `true`, even though it may not be |
| 171 | + // nonnull, since some other operand in a cycle of phis may invalidate the |
| 172 | + // optimistic assumption that the entire cycle is nonnull, including this phi. |
| 173 | + if (!VisitedPhis->insert(PN)) |
| 174 | + return true; |
| 175 | + |
| 176 | + // Use a sensible limit to avoid iterating over long chains of phis that are |
| 177 | + // unlikely to be nonnull. |
| 178 | + if (VisitedPhis->size() >= kPhiLimit) |
| 179 | + return false; |
| 180 | + |
| 181 | + unsigned numOperands = PN->getNumOperands(); |
| 182 | + for (unsigned i = 0; i < numOperands; ++i) { |
| 183 | + Value *SrcValue = PN->getOperand(i); |
| 184 | + if (NonNullOrPoisonValues.count(SrcValue)) { |
| 185 | + continue; |
| 186 | + } else if (auto *SrcPN = dyn_cast<PHINode>(SrcValue)) { |
| 187 | + if (!isNonNullOrPoisonPhi(VisitedPhis, SrcPN)) |
| 188 | + return false; |
| 189 | + } else { |
| 190 | + return false; |
| 191 | + } |
| 192 | + } |
| 193 | + |
| 194 | + return true; |
| 195 | +} |
| 196 | + |
| 197 | +/// Determines whether an ICmpInst is a null check of a known nonnull-or-poison |
| 198 | +/// value. |
| 199 | +NullCheckElimination::NullCheckResult |
| 200 | +NullCheckElimination::isCmpNullCheck(ICmpInst *CI) { |
| 201 | + if (!CI->isEquality()) |
| 202 | + return NotNullCheck; |
| 203 | + |
| 204 | + unsigned constantIndex; |
| 205 | + if (NonNullOrPoisonValues.count(CI->getOperand(0))) |
| 206 | + constantIndex = 1; |
| 207 | + else if (NonNullOrPoisonValues.count(CI->getOperand(1))) |
| 208 | + constantIndex = 0; |
| 209 | + else |
| 210 | + return NotNullCheck; |
| 211 | + |
| 212 | + auto *C = dyn_cast<Constant>(CI->getOperand(constantIndex)); |
| 213 | + if (!C || !C->isZeroValue()) |
| 214 | + return NotNullCheck; |
| 215 | + |
| 216 | + return |
| 217 | + CI->getPredicate() == llvm::CmpInst::ICMP_EQ ? NullCheckEq : NullCheckNe; |
| 218 | +} |
| 219 | + |
| 220 | +/// Finds the Use, if any, of an ICmpInst null check of a nonnull-or-poison |
| 221 | +/// value. |
| 222 | +std::pair<Use*, NullCheckElimination::NullCheckResult> |
| 223 | +NullCheckElimination::findNullCheck(Use *U) { |
| 224 | + auto *I = dyn_cast<Instruction>(U->get()); |
| 225 | + if (!I) |
| 226 | + return std::make_pair(nullptr, NotNullCheck); |
| 227 | + |
| 228 | + if (auto *CI = dyn_cast<ICmpInst>(I)) { |
| 229 | + NullCheckResult result = isCmpNullCheck(CI); |
| 230 | + if (result == NotNullCheck) |
| 231 | + return std::make_pair(nullptr, NotNullCheck); |
| 232 | + else |
| 233 | + return std::make_pair(U, result); |
| 234 | + } |
| 235 | + |
| 236 | + unsigned opcode = I->getOpcode(); |
| 237 | + if (opcode == Instruction::Or || opcode == Instruction::And) { |
| 238 | + auto result = findNullCheck(&I->getOperandUse(0)); |
| 239 | + if (result.second == NotNullCheck) |
| 240 | + return findNullCheck(&I->getOperandUse(1)); |
| 241 | + else |
| 242 | + return result; |
| 243 | + } |
| 244 | + |
| 245 | + return std::make_pair(nullptr, NotNullCheck); |
| 246 | +} |
| 247 | + |
| 248 | +/// Determines whether `BB` contains a load from `PtrV`, or any inbounds GEP |
| 249 | +/// derived from `PtrV`. |
| 250 | +bool |
| 251 | +NullCheckElimination::blockContainsLoadDerivedFrom(BasicBlock *BB, |
| 252 | + Value *PtrV) { |
| 253 | + for (auto &I : *BB) { |
| 254 | + auto *LI = dyn_cast<LoadInst>(&I); |
| 255 | + if (!LI) |
| 256 | + continue; |
| 257 | + |
| 258 | + Value *V = LI->getPointerOperand(); |
| 259 | + while (NonNullOrPoisonValues.count(V)) { |
| 260 | + if (V == PtrV) |
| 261 | + return true; |
| 262 | + |
| 263 | + auto *GEP = dyn_cast<GetElementPtrInst>(V); |
| 264 | + if (!GEP) |
| 265 | + break; |
| 266 | + |
| 267 | + V = GEP->getOperand(0); |
| 268 | + } |
| 269 | + } |
| 270 | + |
| 271 | + return false; |
| 272 | +} |
| 273 | + |
0 commit comments