Skip to content
This repository was archived by the owner on Feb 5, 2019. It is now read-only.

Commit dd684dc

Browse files
author
Cameron Zwarich
committed
Add a NullCheckElimination pass
This pass is not Rust-specific, in that all of its transformations are intended to be correct for arbitrary LLVM IR, but it targets idioms found in IR generated by `rustc`, e.g. heavy use of `inbounds` GEPs.
1 parent 1bba097 commit dd684dc

File tree

7 files changed

+453
-0
lines changed

7 files changed

+453
-0
lines changed

include/llvm/InitializePasses.h

+3
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ void initializeBBVectorizePass(PassRegistry&);
275275
void initializeMachineFunctionPrinterPassPass(PassRegistry&);
276276
void initializeStackMapLivenessPass(PassRegistry&);
277277
void initializeLoadCombinePass(PassRegistry&);
278+
279+
// Specific to the rust-lang llvm branch:
280+
void initializeNullCheckEliminationPass(PassRegistry&);
278281
}
279282

280283
#endif

include/llvm/LinkAllPasses.h

+3
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ namespace {
160160
(void) llvm::createScalarizerPass();
161161
(void) llvm::createSeparateConstOffsetFromGEPPass();
162162

163+
// Specific to the rust-lang llvm branch:
164+
(void) llvm::createNullCheckEliminationPass();
165+
163166
(void)new llvm::IntervalPartition();
164167
(void)new llvm::FindUsedTypes();
165168
(void)new llvm::ScalarEvolution();

include/llvm/Transforms/Scalar.h

+7
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,13 @@ FunctionPass *createSeparateConstOffsetFromGEPPass();
388388
//
389389
BasicBlockPass *createLoadCombinePass();
390390

391+
// Specific to the rust-lang llvm branch:
392+
//===----------------------------------------------------------------------===//
393+
//
394+
// NullCheckElimination - Eliminate null checks.
395+
//
396+
FunctionPass *createNullCheckEliminationPass();
397+
391398
} // End llvm namespace
392399

393400
#endif

lib/Transforms/Scalar/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ add_llvm_library(LLVMScalarOpts
2222
LoopUnswitch.cpp
2323
LowerAtomic.cpp
2424
MemCpyOptimizer.cpp
25+
NullCheckElimination.cpp
2526
PartiallyInlineLibCalls.cpp
2627
Reassociate.cpp
2728
Reg2Mem.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
//===-- NullCheckElimination.cpp - Null Check Elimination Pass ------------===//
2+
//
3+
// The LLVM Compiler Infrastructure
4+
//
5+
// This file is distributed under the University of Illinois Open Source
6+
// License. See LICENSE.TXT for details.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "llvm/Transforms/Scalar.h"
11+
#include "llvm/ADT/DenseSet.h"
12+
#include "llvm/ADT/SmallPtrSet.h"
13+
#include "llvm/ADT/Statistic.h"
14+
#include "llvm/IR/Constants.h"
15+
#include "llvm/IR/Function.h"
16+
#include "llvm/IR/Instructions.h"
17+
#include "llvm/Pass.h"
18+
using namespace llvm;
19+
20+
#define DEBUG_TYPE "null-check-elimination"
21+
22+
namespace {
23+
struct NullCheckElimination : public FunctionPass {
24+
static char ID;
25+
NullCheckElimination() : FunctionPass(ID) {
26+
initializeNullCheckEliminationPass(*PassRegistry::getPassRegistry());
27+
}
28+
bool runOnFunction(Function &F) override;
29+
30+
void getAnalysisUsage(AnalysisUsage &AU) const override {
31+
AU.setPreservesCFG();
32+
}
33+
34+
private:
35+
static const unsigned kPhiLimit = 16;
36+
typedef SmallPtrSet<PHINode*, kPhiLimit> SmallPhiSet;
37+
enum NullCheckResult {
38+
NotNullCheck,
39+
NullCheckEq,
40+
NullCheckNe,
41+
};
42+
43+
bool isNonNullOrPoisonPhi(SmallPhiSet *VisitedPhis, PHINode*);
44+
45+
NullCheckResult isCmpNullCheck(ICmpInst*);
46+
std::pair<Use*, NullCheckResult> findNullCheck(Use*);
47+
48+
bool blockContainsLoadDerivedFrom(BasicBlock*, Value*);
49+
50+
DenseSet<Value*> NonNullOrPoisonValues;
51+
};
52+
}
53+
54+
char NullCheckElimination::ID = 0;
55+
INITIALIZE_PASS_BEGIN(NullCheckElimination,
56+
"null-check-elimination",
57+
"Null Check Elimination",
58+
false, false)
59+
INITIALIZE_PASS_END(NullCheckElimination,
60+
"null-check-elimination",
61+
"Null Check Elimination",
62+
false, false)
63+
64+
FunctionPass *llvm::createNullCheckEliminationPass() {
65+
return new NullCheckElimination();
66+
}
67+
68+
bool NullCheckElimination::runOnFunction(Function &F) {
69+
if (skipOptnoneFunction(F))
70+
return false;
71+
72+
bool Changed = false;
73+
74+
// Collect argumetns with the `nonnull` attribute.
75+
for (auto &Arg : F.args()) {
76+
if (Arg.hasNonNullAttr())
77+
NonNullOrPoisonValues.insert(&Arg);
78+
}
79+
80+
// Collect instructions that definitely produce nonnull-or-poison values.
81+
// At the moment, this is restricted to inbounds GEPs. It would be slightly
82+
// more difficult to include uses of values dominated by a null check, since
83+
// then we would have to consider uses instead of mere values.
84+
for (auto &BB : F) {
85+
for (auto &I : BB) {
86+
if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
87+
if (GEP->isInBounds()) {
88+
NonNullOrPoisonValues.insert(GEP);
89+
}
90+
}
91+
}
92+
}
93+
94+
// Find phis that are derived entirely from nonnull-or-poison values,
95+
// including other phis that are themselves derived entirely from these
96+
// values.
97+
for (auto &BB : F) {
98+
for (auto &I : BB) {
99+
auto *PN = dyn_cast<PHINode>(&I);
100+
if (!PN)
101+
break;
102+
103+
SmallPhiSet VisitedPHIs;
104+
if (isNonNullOrPoisonPhi(&VisitedPHIs, PN))
105+
NonNullOrPoisonValues.insert(PN);
106+
}
107+
}
108+
109+
for (auto &BB : F) {
110+
// This could also be extended to handle SwitchInst, but using a SwitchInst
111+
// for a null check seems unlikely.
112+
auto *BI = dyn_cast<BranchInst>(BB.getTerminator());
113+
if (!BI || BI->isUnconditional())
114+
continue;
115+
116+
// The first operand of a conditional branch is the condition.
117+
auto result = findNullCheck(&BI->getOperandUse(0));
118+
if (!result.first)
119+
continue;
120+
assert((result.second == NullCheckEq || result.second == NullCheckNe) &&
121+
"valid null check kind expected if ICmpInst was found");
122+
123+
BasicBlock *NonNullBB;
124+
if (result.second == NullCheckEq) {
125+
// If the comparison instruction is checking for equaliity with null,
126+
// then the pointer is nonnull on the `false` branch.
127+
NonNullBB = BI->getSuccessor(1);
128+
} else {
129+
// Otherwise, if the comparison instruction is checking for inequality
130+
// with null, the pointer is nonnull on the `true` branch.
131+
NonNullBB = BI->getSuccessor(0);
132+
}
133+
134+
Use *U = result.first;
135+
ICmpInst *CI = cast<ICmpInst>(U->get());
136+
unsigned nonConstantIndex;
137+
if (isa<Constant>(CI->getOperand(0)))
138+
nonConstantIndex = 1;
139+
else
140+
nonConstantIndex = 0;
141+
142+
// Due to the semantics of poison values in LLVM, we have to check that
143+
// there is actually some externally visible side effect that is dependent
144+
// on the poison value. Since poison values are otherwise treated as undef,
145+
// and a load of undef is undefined behavior (which is externally visible),
146+
// it suffices to look for a load of the nonnull-or-poison value.
147+
//
148+
// This could be extended to any block control-dependent on this branch of
149+
// the null check, it's unclear if that will actually catch more cases in
150+
// real code.
151+
Value *PtrV = CI->getOperand(nonConstantIndex);
152+
if (blockContainsLoadDerivedFrom(NonNullBB, PtrV)) {
153+
Type *BoolTy = CI->getType();
154+
Value *NewV = ConstantInt::get(BoolTy, result.second == NullCheckNe);
155+
U->set(NewV);
156+
}
157+
}
158+
159+
NonNullOrPoisonValues.clear();
160+
161+
return Changed;
162+
}
163+
164+
/// Checks whether a phi is derived from known nonnnull-or-poison values,
165+
/// including other phis that are derived from the same. May return `false`
166+
/// conservatively in some cases, e.g. if exploring a large cycle of phis.
167+
bool
168+
NullCheckElimination::isNonNullOrPoisonPhi(SmallPhiSet *VisitedPhis,
169+
PHINode *PN) {
170+
// If we've already seen this phi, return `true`, even though it may not be
171+
// nonnull, since some other operand in a cycle of phis may invalidate the
172+
// optimistic assumption that the entire cycle is nonnull, including this phi.
173+
if (!VisitedPhis->insert(PN))
174+
return true;
175+
176+
// Use a sensible limit to avoid iterating over long chains of phis that are
177+
// unlikely to be nonnull.
178+
if (VisitedPhis->size() >= kPhiLimit)
179+
return false;
180+
181+
unsigned numOperands = PN->getNumOperands();
182+
for (unsigned i = 0; i < numOperands; ++i) {
183+
Value *SrcValue = PN->getOperand(i);
184+
if (NonNullOrPoisonValues.count(SrcValue)) {
185+
continue;
186+
} else if (auto *SrcPN = dyn_cast<PHINode>(SrcValue)) {
187+
if (!isNonNullOrPoisonPhi(VisitedPhis, SrcPN))
188+
return false;
189+
} else {
190+
return false;
191+
}
192+
}
193+
194+
return true;
195+
}
196+
197+
/// Determines whether an ICmpInst is a null check of a known nonnull-or-poison
198+
/// value.
199+
NullCheckElimination::NullCheckResult
200+
NullCheckElimination::isCmpNullCheck(ICmpInst *CI) {
201+
if (!CI->isEquality())
202+
return NotNullCheck;
203+
204+
unsigned constantIndex;
205+
if (NonNullOrPoisonValues.count(CI->getOperand(0)))
206+
constantIndex = 1;
207+
else if (NonNullOrPoisonValues.count(CI->getOperand(1)))
208+
constantIndex = 0;
209+
else
210+
return NotNullCheck;
211+
212+
auto *C = dyn_cast<Constant>(CI->getOperand(constantIndex));
213+
if (!C || !C->isZeroValue())
214+
return NotNullCheck;
215+
216+
return
217+
CI->getPredicate() == llvm::CmpInst::ICMP_EQ ? NullCheckEq : NullCheckNe;
218+
}
219+
220+
/// Finds the Use, if any, of an ICmpInst null check of a nonnull-or-poison
221+
/// value.
222+
std::pair<Use*, NullCheckElimination::NullCheckResult>
223+
NullCheckElimination::findNullCheck(Use *U) {
224+
auto *I = dyn_cast<Instruction>(U->get());
225+
if (!I)
226+
return std::make_pair(nullptr, NotNullCheck);
227+
228+
if (auto *CI = dyn_cast<ICmpInst>(I)) {
229+
NullCheckResult result = isCmpNullCheck(CI);
230+
if (result == NotNullCheck)
231+
return std::make_pair(nullptr, NotNullCheck);
232+
else
233+
return std::make_pair(U, result);
234+
}
235+
236+
unsigned opcode = I->getOpcode();
237+
if (opcode == Instruction::Or || opcode == Instruction::And) {
238+
auto result = findNullCheck(&I->getOperandUse(0));
239+
if (result.second == NotNullCheck)
240+
return findNullCheck(&I->getOperandUse(1));
241+
else
242+
return result;
243+
}
244+
245+
return std::make_pair(nullptr, NotNullCheck);
246+
}
247+
248+
/// Determines whether `BB` contains a load from `PtrV`, or any inbounds GEP
249+
/// derived from `PtrV`.
250+
bool
251+
NullCheckElimination::blockContainsLoadDerivedFrom(BasicBlock *BB,
252+
Value *PtrV) {
253+
for (auto &I : *BB) {
254+
auto *LI = dyn_cast<LoadInst>(&I);
255+
if (!LI)
256+
continue;
257+
258+
Value *V = LI->getPointerOperand();
259+
while (NonNullOrPoisonValues.count(V)) {
260+
if (V == PtrV)
261+
return true;
262+
263+
auto *GEP = dyn_cast<GetElementPtrInst>(V);
264+
if (!GEP)
265+
break;
266+
267+
V = GEP->getOperand(0);
268+
}
269+
}
270+
271+
return false;
272+
}
273+

lib/Transforms/Scalar/Scalar.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
6666
initializeTailCallElimPass(Registry);
6767
initializeSeparateConstOffsetFromGEPPass(Registry);
6868
initializeLoadCombinePass(Registry);
69+
initializeNullCheckEliminationPass(Registry);
6970
}
7071

7172
void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) {

0 commit comments

Comments
 (0)