Skip to content

Commit b7e2f14

Browse files
authored
Fix nondeterministic cache errors (rust-lang#237)
* Fixing internal nondeterminacy * Fixing internal nondeterminacy * Fix load minCut
1 parent ae09f3d commit b7e2f14

32 files changed

+874
-556
lines changed

Diff for: enzyme/Enzyme/AdjointGenerator.h

+116-136
Large diffs are not rendered by default.

Diff for: enzyme/Enzyme/CacheUtility.cpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -728,8 +728,14 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
728728
ConstantInt::get(Type::getInt64Ty(T->getContext()), 3));
729729
}
730730
if (extraSize && i == 0) {
731-
Value *es = unwrapM(extraSize, allocationBuilder,
732-
/*available*/ ValueToValueMapTy(),
731+
ValueToValueMapTy available;
732+
for (auto &sl : sublimits) {
733+
for (auto &cl : sl.second) {
734+
if (cl.first.var)
735+
available[cl.first.var] = cl.first.var;
736+
}
737+
}
738+
Value *es = unwrapM(extraSize, allocationBuilder, available,
733739
UnwrapMode::AttemptFullUnwrapWithLookup);
734740
assert(es);
735741
size = allocationBuilder.CreateMul(size, es, "", /*NUW*/ true,
@@ -1346,8 +1352,9 @@ Value *CacheUtility::getCachePointer(bool inForwardPass, IRBuilder<> &BuilderM,
13461352
8);
13471353
cast<LoadInst>(next)->setMetadata(
13481354
LLVMContext::MD_dereferenceable,
1349-
MDNode::get(cache->getContext(),
1350-
{ConstantAsMetadata::get(byteSizeOfType)}));
1355+
MDNode::get(
1356+
cache->getContext(),
1357+
ArrayRef<Metadata *>(ConstantAsMetadata::get(byteSizeOfType))));
13511358
unsigned bsize = (unsigned)byteSizeOfType->getZExtValue();
13521359
if ((bsize & (bsize - 1)) == 0) {
13531360
#if LLVM_VERSION_MAJOR >= 10

Diff for: enzyme/Enzyme/CacheUtility.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ static inline bool operator==(const LoopContext &lhs, const LoopContext &rhs) {
8585
enum class UnwrapMode {
8686
// It is already known that it is legal to fully unwrap
8787
// this instruction. This means unwrap this instruction,
88-
// its operands, etc
88+
// its operands, etc. However, this will stop at known
89+
// cached available from a tape.
8990
LegalFullUnwrap,
91+
// Unlike LegalFullUnwrap, this will unwrap through a tape
92+
LegalFullUnwrapNoTapeReplace,
9093
// Attempt to fully unwrap this, looking up whenever it
9194
// is not legal to unwrap
9295
AttemptFullUnwrapWithLookup,
@@ -102,6 +105,9 @@ static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
102105
case UnwrapMode::LegalFullUnwrap:
103106
os << "LegalFullUnwrap";
104107
break;
108+
case UnwrapMode::LegalFullUnwrapNoTapeReplace:
109+
os << "LegalFullUnwrapNoTapeReplace";
110+
break;
105111
case UnwrapMode::AttemptFullUnwrapWithLookup:
106112
os << "AttemptFullUnwrapWithLookup";
107113
break;

Diff for: enzyme/Enzyme/DifferentialUseAnalysis.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,12 @@ static inline bool is_value_needed_in_reverse(
241241
// One may need to this value in the computation of loop
242242
// bounds/comparisons/etc (which even though not active -- will be used for
243243
// the reverse pass)
244-
// We only need this if we're not doing the combined forward/reverse since
244+
// We could potentially optimize this to avoid caching if in combined mode
245+
// and the instruction dominates all returns
245246
// otherwise it will use the local cache (rather than save for a separate
246247
// backwards cache)
247248
// We also don't need this if looking at the shadow rather than primal
248-
if (mode != DerivativeMode::ReverseModeCombined) {
249+
{
249250
// Proving that none of the uses (or uses' uses) are used in control flow
250251
// allows us to safely not do this load
251252

Diff for: enzyme/Enzyme/Enzyme.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ class Enzyme : public ModulePass {
499499
->getElementType(tapeIdx);
500500
}
501501
if (tapeType &&
502-
DL.getTypeSizeInBits(tapeType) < 8 * allocatedTapeSize) {
502+
DL.getTypeSizeInBits(tapeType) < 8 * (size_t)allocatedTapeSize) {
503503
auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
504504
EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(),
505505
CI, "need ", bytes, " bytes have ", allocatedTapeSize,

Diff for: enzyme/Enzyme/EnzymeLogic.cpp

+88-60
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,12 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
8585
cl::desc("Consider loads of nonmarked globals to be inactive"));
8686
}
8787

88-
bool is_load_uncacheable(
89-
LoadInst &li, AAResults &AA, Function *oldFunc, TargetLibraryInfo &TLI,
90-
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
91-
const std::map<Argument *, bool> &uncacheable_args, DerivativeMode mode);
92-
9388
struct CacheAnalysis {
9489
AAResults &AA;
9590
Function *oldFunc;
9691
ScalarEvolution &SE;
9792
LoopInfo &OrigLI;
98-
DominatorTree &DT;
93+
DominatorTree &OrigDT;
9994
TargetLibraryInfo &TLI;
10095
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions;
10196
const std::map<Argument *, bool> &uncacheable_args;
@@ -106,8 +101,8 @@ struct CacheAnalysis {
106101
DominatorTree &OrigDT, TargetLibraryInfo &TLI,
107102
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
108103
const std::map<Argument *, bool> &uncacheable_args, DerivativeMode mode)
109-
: AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), DT(OrigDT), TLI(TLI),
110-
unnecessaryInstructions(unnecessaryInstructions),
104+
: AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), OrigDT(OrigDT),
105+
TLI(TLI), unnecessaryInstructions(unnecessaryInstructions),
111106
uncacheable_args(uncacheable_args), mode(mode) {}
112107

113108
bool is_value_mustcache_from_origin(Value *obj) {
@@ -288,8 +283,8 @@ struct CacheAnalysis {
288283
auto SH = SExpr->getLoop()->getHeader();
289284
if (auto LExpr = dyn_cast<SCEVAddRecExpr>(lim)) {
290285
auto LH = LExpr->getLoop()->getHeader();
291-
if (SH != LH && !DT.dominates(SH, LH) &&
292-
!DT.dominates(LH, SH)) {
286+
if (SH != LH && !OrigDT.dominates(SH, LH) &&
287+
!OrigDT.dominates(LH, SH)) {
293288
check = false;
294289
}
295290
}
@@ -349,8 +344,8 @@ struct CacheAnalysis {
349344
auto SH = SExpr->getLoop()->getHeader();
350345
if (auto LExpr = dyn_cast<SCEVAddRecExpr>(lim)) {
351346
auto LH = LExpr->getLoop()->getHeader();
352-
if (SH != LH && !DT.dominates(SH, LH) &&
353-
!DT.dominates(LH, SH)) {
347+
if (SH != LH && !OrigDT.dominates(SH, LH) &&
348+
!OrigDT.dominates(LH, SH)) {
354349
check = false;
355350
}
356351
}
@@ -477,15 +472,16 @@ struct CacheAnalysis {
477472

478473
std::map<Argument *, bool>
479474
compute_uncacheable_args_for_one_callsite(CallInst *callsite_op) {
475+
Function *Fn = callsite_op->getCalledFunction();
480476

481-
if (!callsite_op->getCalledFunction())
477+
if (!Fn)
482478
return {};
483479

484-
if (isMemFreeLibMFunction(callsite_op->getCalledFunction()->getName())) {
480+
if (isMemFreeLibMFunction(Fn->getName())) {
485481
return {};
486482
}
487483

488-
if (isCertainMallocOrFree(callsite_op->getCalledFunction())) {
484+
if (isCertainPrintMallocOrFree(Fn)) {
489485
return {};
490486
}
491487
std::vector<Value *> args;
@@ -537,7 +533,7 @@ struct CacheAnalysis {
537533
}
538534
}
539535
}
540-
if (called && isCertainMallocOrFree(called)) {
536+
if (called && isCertainPrintMallocOrFree(called)) {
541537
return false;
542538
}
543539
if (called && isMemFreeLibMFunction(called->getName())) {
@@ -1635,20 +1631,28 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
16351631
}
16361632
}
16371633

1638-
for (const auto &m : gutils->knownRecomputeHeuristic) {
1639-
if (!m.second && !isa<LoadInst>(m.first) && !isa<CallInst>(m.first)) {
1640-
auto newi = gutils->getNewFromOriginal(m.first);
1641-
IRBuilder<> BuilderZ(cast<Instruction>(newi)->getNextNode());
1642-
if (isa<PHINode>(newi)) {
1643-
BuilderZ.SetInsertPoint(
1644-
cast<Instruction>(newi)->getParent()->getFirstNonPHI());
1634+
if (gutils->knownRecomputeHeuristic.size()) {
1635+
// Even though we could simply iterate through the heuristic map,
1636+
// we explicity iterate in order of the instructions to maintain
1637+
// a deterministic cache ordering.
1638+
for (auto &BB : *gutils->oldFunc)
1639+
for (auto &I : BB) {
1640+
auto found = gutils->knownRecomputeHeuristic.find(&I);
1641+
if (found != gutils->knownRecomputeHeuristic.end()) {
1642+
if (!found->second && !isa<CallInst>(&I)) {
1643+
auto newi = gutils->getNewFromOriginal(&I);
1644+
IRBuilder<> BuilderZ(cast<Instruction>(newi)->getNextNode());
1645+
if (isa<PHINode>(newi)) {
1646+
BuilderZ.SetInsertPoint(
1647+
cast<Instruction>(newi)->getParent()->getFirstNonPHI());
1648+
}
1649+
gutils->cacheForReverse(BuilderZ, newi,
1650+
getIndex(&I, CacheType::Self));
1651+
}
1652+
}
16451653
}
1646-
gutils->cacheForReverse(
1647-
BuilderZ, newi,
1648-
getIndex(cast<Instruction>(const_cast<Value *>(m.first)),
1649-
CacheType::Self));
1650-
}
16511654
}
1655+
16521656
auto nf = gutils->newFunc;
16531657

16541658
while (gutils->inversionAllocs->size() > 0) {
@@ -2068,12 +2072,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
20682072
user->setCalledFunction(NewF);
20692073
}
20702074
}
2071-
PPC.AlwaysInline(gutils->newFunc);
2075+
PPC.AlwaysInline(NewF);
20722076
auto Arch = llvm::Triple(NewF->getParent()->getTargetTriple()).getArch();
20732077
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
20742078
PPC.ReplaceReallocs(NewF, /*mem2reg*/ true);
2075-
if (PostOpt)
2076-
PPC.optimizeIntermediate(NewF);
20772079

20782080
AugmentedCachedFunctions.find(tup)->second.fn = NewF;
20792081
if (recursive || (omp && !noTape))
@@ -2087,6 +2089,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
20872089
gutils->newFunc->eraseFromParent();
20882090

20892091
delete gutils;
2092+
if (PostOpt)
2093+
PPC.optimizeIntermediate(NewF);
20902094
if (EnzymePrint)
20912095
llvm::errs() << *NewF << "\n";
20922096
return AugmentedCachedFunctions.find(tup)->second;
@@ -2993,35 +2997,55 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
29932997

29942998
if (mode != DerivativeMode::ReverseModeCombined &&
29952999
mode != DerivativeMode::ForwardMode) {
2996-
std::map<Value *, std::vector<Value *>> unwrapToOrig;
2997-
for (auto pair : gutils->unwrappedLoads)
2998-
unwrapToOrig[pair.second].push_back(const_cast<Value *>(pair.first));
2999-
std::map<Value *, Value *> newIToNextI;
3000+
// One must use this temporary map to first create all the replacements
3001+
// prior to actually replacing to ensure that getSubLimits has the same
3002+
// behavior and unwrap behavior for all replacements.
3003+
std::vector<std::pair<Instruction *, Value *>> newIToNextI;
3004+
30003005
for (const auto &m : mapping) {
3001-
if (m.first.second == CacheType::Self && !isa<LoadInst>(m.first.first) &&
3002-
!isa<CallInst>(m.first.first)) {
3006+
if (m.first.second == CacheType::Self && !isa<CallInst>(m.first.first) &&
3007+
gutils->knownRecomputeHeuristic.count(m.first.first)) {
3008+
assert(gutils->knownRecomputeHeuristic.count(m.first.first));
30033009
auto newi = gutils->getNewFromOriginal(m.first.first);
30043010
if (auto PN = dyn_cast<PHINode>(newi))
3005-
if (gutils->fictiousPHIs.count(PN))
3011+
if (gutils->fictiousPHIs.count(PN)) {
3012+
assert(gutils->fictiousPHIs[PN] == m.first.first);
30063013
gutils->fictiousPHIs.erase(PN);
3014+
}
30073015
IRBuilder<> BuilderZ(newi->getNextNode());
30083016
if (isa<PHINode>(m.first.first)) {
30093017
BuilderZ.SetInsertPoint(
30103018
cast<Instruction>(newi)->getParent()->getFirstNonPHI());
30113019
}
3012-
Value *nexti = gutils->cacheForReverse(BuilderZ, newi, m.second);
3013-
for (auto V : unwrapToOrig[newi]) {
3014-
ValueToValueMapTy empty;
3015-
IRBuilder<> lb(cast<Instruction>(V));
3016-
// This must disallow caching here as otherwise performing the loop in
3017-
// the wrong order may result in first replacing the later unwrapped
3018-
// value, caching it, then attempting to reuse it for an earlier
3019-
// replacement.
3020-
V->replaceAllUsesWith(
3021-
gutils->unwrapM(nexti, lb, empty, UnwrapMode::LegalFullUnwrap,
3022-
/*scope*/ nullptr, /*permitCache*/ false));
3023-
cast<Instruction>(V)->eraseFromParent();
3024-
}
3020+
Value *nexti = gutils->cacheForReverse(
3021+
BuilderZ, newi, m.second, /*ignoreType*/ false, /*replace*/ false);
3022+
newIToNextI.emplace_back(newi, nexti);
3023+
}
3024+
}
3025+
3026+
std::map<Value *, std::vector<Instruction *>> unwrapToOrig;
3027+
for (auto pair : gutils->unwrappedLoads)
3028+
unwrapToOrig[pair.second].push_back(
3029+
const_cast<Instruction *>(pair.first));
3030+
gutils->unwrappedLoads.clear();
3031+
for (auto pair : newIToNextI) {
3032+
auto newi = pair.first;
3033+
auto nexti = pair.second;
3034+
newi->replaceAllUsesWith(nexti);
3035+
gutils->erase(newi);
3036+
for (auto V : unwrapToOrig[newi]) {
3037+
ValueToValueMapTy empty;
3038+
IRBuilder<> lb(V);
3039+
// This must disallow caching here as otherwise performing the loop in
3040+
// the wrong order may result in first replacing the later unwrapped
3041+
// value, caching it, then attempting to reuse it for an earlier
3042+
// replacement.
3043+
Value *nval = gutils->unwrapM(nexti, lb, empty,
3044+
UnwrapMode::LegalFullUnwrapNoTapeReplace,
3045+
/*scope*/ nullptr, /*permitCache*/ false);
3046+
assert(nval);
3047+
V->replaceAllUsesWith(nval);
3048+
V->eraseFromParent();
30253049
}
30263050
}
30273051

@@ -3044,13 +3068,17 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
30443068

30453069
if (auto bi = dyn_cast<BranchInst>(BB.getTerminator())) {
30463070

3047-
Value *vals[1] = {gutils->getNewFromOriginal(bi->getCondition())};
3048-
if (bi->getSuccessor(0) == unreachables[0]) {
3049-
gutils->replaceAWithB(vals[0],
3050-
ConstantInt::getFalse(vals[0]->getContext()));
3051-
} else {
3052-
gutils->replaceAWithB(vals[0],
3053-
ConstantInt::getTrue(vals[0]->getContext()));
3071+
Value *condition = gutils->getNewFromOriginal(bi->getCondition());
3072+
3073+
Constant *repVal = (bi->getSuccessor(0) == unreachables[0])
3074+
? ConstantInt::getFalse(condition->getContext())
3075+
: ConstantInt::getTrue(condition->getContext());
3076+
3077+
for (auto UI = condition->use_begin(), E = condition->use_end();
3078+
UI != E;) {
3079+
Use &U = *UI;
3080+
++UI;
3081+
U.set(repVal);
30543082
}
30553083
}
30563084
}
@@ -3219,11 +3247,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
32193247
PPC.AlwaysInline(gutils->newFunc);
32203248
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
32213249
PPC.ReplaceReallocs(gutils->newFunc, /*mem2reg*/ true);
3222-
if (PostOpt)
3223-
PPC.optimizeIntermediate(gutils->newFunc);
32243250

32253251
auto nf = gutils->newFunc;
32263252
delete gutils;
3253+
if (PostOpt)
3254+
PPC.optimizeIntermediate(nf);
32273255
if (EnzymePrint) {
32283256
llvm::errs() << *nf << "\n";
32293257
}

Diff for: enzyme/Enzyme/FunctionUtils.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ static inline void UpgradeAllocasToMallocs(Function *NewF,
307307
if (auto C = dyn_cast<CastInst>(rep))
308308
CI = cast<CallInst>(C->getOperand(0));
309309
CI->setMetadata("enzyme_fromstack", MDNode::get(CI->getContext(), {}));
310+
CI->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
310311
assert(rep->getType() == AI->getType());
311312
AI->replaceAllUsesWith(rep);
312313
AI->eraseFromParent();
@@ -527,7 +528,7 @@ void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
527528

528529
std::vector<CallInst *> ToConvert;
529530
std::map<CallInst *, Value *> reallocSizes;
530-
IntegerType *T;
531+
IntegerType *T = nullptr;
531532

532533
for (auto &BB : *NewF) {
533534
for (auto &I : BB) {
@@ -547,6 +548,7 @@ void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
547548
std::vector<AllocaInst *> memoryLocations;
548549

549550
for (auto CI : ToConvert) {
551+
assert(T);
550552
AllocaInst *AI =
551553
OldAllocationSize(CI->getArgOperand(0), CI, NewF, T, reallocSizes);
552554

0 commit comments

Comments
 (0)