@@ -85,17 +85,12 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
85
85
cl::desc(" Consider loads of nonmarked globals to be inactive" ));
86
86
}
87
87
88
- bool is_load_uncacheable (
89
- LoadInst &li, AAResults &AA, Function *oldFunc, TargetLibraryInfo &TLI,
90
- const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
91
- const std::map<Argument *, bool > &uncacheable_args, DerivativeMode mode);
92
-
93
88
struct CacheAnalysis {
94
89
AAResults &AA;
95
90
Function *oldFunc;
96
91
ScalarEvolution &SE;
97
92
LoopInfo &OrigLI;
98
- DominatorTree &DT ;
93
+ DominatorTree &OrigDT ;
99
94
TargetLibraryInfo &TLI;
100
95
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions;
101
96
const std::map<Argument *, bool > &uncacheable_args;
@@ -106,8 +101,8 @@ struct CacheAnalysis {
106
101
DominatorTree &OrigDT, TargetLibraryInfo &TLI,
107
102
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
108
103
const std::map<Argument *, bool > &uncacheable_args, DerivativeMode mode)
109
- : AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), DT( OrigDT), TLI(TLI ),
110
- unnecessaryInstructions (unnecessaryInstructions),
104
+ : AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), OrigDT(OrigDT ),
105
+ TLI (TLI), unnecessaryInstructions(unnecessaryInstructions),
111
106
uncacheable_args(uncacheable_args), mode(mode) {}
112
107
113
108
bool is_value_mustcache_from_origin (Value *obj) {
@@ -288,8 +283,8 @@ struct CacheAnalysis {
288
283
auto SH = SExpr->getLoop ()->getHeader ();
289
284
if (auto LExpr = dyn_cast<SCEVAddRecExpr>(lim)) {
290
285
auto LH = LExpr->getLoop ()->getHeader ();
291
- if (SH != LH && !DT .dominates (SH, LH) &&
292
- !DT .dominates (LH, SH)) {
286
+ if (SH != LH && !OrigDT .dominates (SH, LH) &&
287
+ !OrigDT .dominates (LH, SH)) {
293
288
check = false ;
294
289
}
295
290
}
@@ -349,8 +344,8 @@ struct CacheAnalysis {
349
344
auto SH = SExpr->getLoop ()->getHeader ();
350
345
if (auto LExpr = dyn_cast<SCEVAddRecExpr>(lim)) {
351
346
auto LH = LExpr->getLoop ()->getHeader ();
352
- if (SH != LH && !DT .dominates (SH, LH) &&
353
- !DT .dominates (LH, SH)) {
347
+ if (SH != LH && !OrigDT .dominates (SH, LH) &&
348
+ !OrigDT .dominates (LH, SH)) {
354
349
check = false ;
355
350
}
356
351
}
@@ -477,15 +472,16 @@ struct CacheAnalysis {
477
472
478
473
std::map<Argument *, bool >
479
474
compute_uncacheable_args_for_one_callsite (CallInst *callsite_op) {
475
+ Function *Fn = callsite_op->getCalledFunction ();
480
476
481
- if (!callsite_op-> getCalledFunction () )
477
+ if (!Fn )
482
478
return {};
483
479
484
- if (isMemFreeLibMFunction (callsite_op-> getCalledFunction () ->getName ())) {
480
+ if (isMemFreeLibMFunction (Fn ->getName ())) {
485
481
return {};
486
482
}
487
483
488
- if (isCertainMallocOrFree (callsite_op-> getCalledFunction () )) {
484
+ if (isCertainPrintMallocOrFree (Fn )) {
489
485
return {};
490
486
}
491
487
std::vector<Value *> args;
@@ -537,7 +533,7 @@ struct CacheAnalysis {
537
533
}
538
534
}
539
535
}
540
- if (called && isCertainMallocOrFree (called)) {
536
+ if (called && isCertainPrintMallocOrFree (called)) {
541
537
return false ;
542
538
}
543
539
if (called && isMemFreeLibMFunction (called->getName ())) {
@@ -1635,20 +1631,28 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
1635
1631
}
1636
1632
}
1637
1633
1638
- for (const auto &m : gutils->knownRecomputeHeuristic ) {
1639
- if (!m.second && !isa<LoadInst>(m.first ) && !isa<CallInst>(m.first )) {
1640
- auto newi = gutils->getNewFromOriginal (m.first );
1641
- IRBuilder<> BuilderZ (cast<Instruction>(newi)->getNextNode ());
1642
- if (isa<PHINode>(newi)) {
1643
- BuilderZ.SetInsertPoint (
1644
- cast<Instruction>(newi)->getParent ()->getFirstNonPHI ());
1634
+ if (gutils->knownRecomputeHeuristic .size ()) {
1635
+ // Even though we could simply iterate through the heuristic map,
1636
+ // we explicitly iterate in order of the instructions to maintain
1637
+ // a deterministic cache ordering.
1638
+ for (auto &BB : *gutils->oldFunc )
1639
+ for (auto &I : BB) {
1640
+ auto found = gutils->knownRecomputeHeuristic .find (&I);
1641
+ if (found != gutils->knownRecomputeHeuristic .end ()) {
1642
+ if (!found->second && !isa<CallInst>(&I)) {
1643
+ auto newi = gutils->getNewFromOriginal (&I);
1644
+ IRBuilder<> BuilderZ (cast<Instruction>(newi)->getNextNode ());
1645
+ if (isa<PHINode>(newi)) {
1646
+ BuilderZ.SetInsertPoint (
1647
+ cast<Instruction>(newi)->getParent ()->getFirstNonPHI ());
1648
+ }
1649
+ gutils->cacheForReverse (BuilderZ, newi,
1650
+ getIndex (&I, CacheType::Self));
1651
+ }
1652
+ }
1645
1653
}
1646
- gutils->cacheForReverse (
1647
- BuilderZ, newi,
1648
- getIndex (cast<Instruction>(const_cast <Value *>(m.first )),
1649
- CacheType::Self));
1650
- }
1651
1654
}
1655
+
1652
1656
auto nf = gutils->newFunc ;
1653
1657
1654
1658
while (gutils->inversionAllocs ->size () > 0 ) {
@@ -2068,12 +2072,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2068
2072
user->setCalledFunction (NewF);
2069
2073
}
2070
2074
}
2071
- PPC.AlwaysInline (gutils-> newFunc );
2075
+ PPC.AlwaysInline (NewF );
2072
2076
auto Arch = llvm::Triple (NewF->getParent ()->getTargetTriple ()).getArch ();
2073
2077
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
2074
2078
PPC.ReplaceReallocs (NewF, /* mem2reg*/ true );
2075
- if (PostOpt)
2076
- PPC.optimizeIntermediate (NewF);
2077
2079
2078
2080
AugmentedCachedFunctions.find (tup)->second .fn = NewF;
2079
2081
if (recursive || (omp && !noTape))
@@ -2087,6 +2089,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2087
2089
gutils->newFunc ->eraseFromParent ();
2088
2090
2089
2091
delete gutils;
2092
+ if (PostOpt)
2093
+ PPC.optimizeIntermediate (NewF);
2090
2094
if (EnzymePrint)
2091
2095
llvm::errs () << *NewF << " \n " ;
2092
2096
return AugmentedCachedFunctions.find (tup)->second ;
@@ -2993,35 +2997,55 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2993
2997
2994
2998
if (mode != DerivativeMode::ReverseModeCombined &&
2995
2999
mode != DerivativeMode::ForwardMode) {
2996
- std::map<Value *, std::vector<Value *>> unwrapToOrig;
2997
- for (auto pair : gutils->unwrappedLoads )
2998
- unwrapToOrig[pair.second ].push_back (const_cast <Value *>(pair.first ));
2999
- std::map<Value *, Value *> newIToNextI;
3000
+ // One must use this temporary map to first create all the replacements
3001
+ // prior to actually replacing to ensure that getSubLimits has the same
3002
+ // behavior and unwrap behavior for all replacements.
3003
+ std::vector<std::pair<Instruction *, Value *>> newIToNextI;
3004
+
3000
3005
for (const auto &m : mapping) {
3001
- if (m.first .second == CacheType::Self && !isa<LoadInst>(m.first .first ) &&
3002
- !isa<CallInst>(m.first .first )) {
3006
+ if (m.first .second == CacheType::Self && !isa<CallInst>(m.first .first ) &&
3007
+ gutils->knownRecomputeHeuristic .count (m.first .first )) {
3008
+ assert (gutils->knownRecomputeHeuristic .count (m.first .first ));
3003
3009
auto newi = gutils->getNewFromOriginal (m.first .first );
3004
3010
if (auto PN = dyn_cast<PHINode>(newi))
3005
- if (gutils->fictiousPHIs .count (PN))
3011
+ if (gutils->fictiousPHIs .count (PN)) {
3012
+ assert (gutils->fictiousPHIs [PN] == m.first .first );
3006
3013
gutils->fictiousPHIs .erase (PN);
3014
+ }
3007
3015
IRBuilder<> BuilderZ (newi->getNextNode ());
3008
3016
if (isa<PHINode>(m.first .first )) {
3009
3017
BuilderZ.SetInsertPoint (
3010
3018
cast<Instruction>(newi)->getParent ()->getFirstNonPHI ());
3011
3019
}
3012
- Value *nexti = gutils->cacheForReverse (BuilderZ, newi, m.second );
3013
- for (auto V : unwrapToOrig[newi]) {
3014
- ValueToValueMapTy empty;
3015
- IRBuilder<> lb (cast<Instruction>(V));
3016
- // This must disallow caching here as otherwise performing the loop in
3017
- // the wrong order may result in first replacing the later unwrapped
3018
- // value, caching it, then attempting to reuse it for an earlier
3019
- // replacement.
3020
- V->replaceAllUsesWith (
3021
- gutils->unwrapM (nexti, lb, empty, UnwrapMode::LegalFullUnwrap,
3022
- /* scope*/ nullptr , /* permitCache*/ false ));
3023
- cast<Instruction>(V)->eraseFromParent ();
3024
- }
3020
+ Value *nexti = gutils->cacheForReverse (
3021
+ BuilderZ, newi, m.second , /* ignoreType*/ false , /* replace*/ false );
3022
+ newIToNextI.emplace_back (newi, nexti);
3023
+ }
3024
+ }
3025
+
3026
+ std::map<Value *, std::vector<Instruction *>> unwrapToOrig;
3027
+ for (auto pair : gutils->unwrappedLoads )
3028
+ unwrapToOrig[pair.second ].push_back (
3029
+ const_cast <Instruction *>(pair.first ));
3030
+ gutils->unwrappedLoads .clear ();
3031
+ for (auto pair : newIToNextI) {
3032
+ auto newi = pair.first ;
3033
+ auto nexti = pair.second ;
3034
+ newi->replaceAllUsesWith (nexti);
3035
+ gutils->erase (newi);
3036
+ for (auto V : unwrapToOrig[newi]) {
3037
+ ValueToValueMapTy empty;
3038
+ IRBuilder<> lb (V);
3039
+ // This must disallow caching here as otherwise performing the loop in
3040
+ // the wrong order may result in first replacing the later unwrapped
3041
+ // value, caching it, then attempting to reuse it for an earlier
3042
+ // replacement.
3043
+ Value *nval = gutils->unwrapM (nexti, lb, empty,
3044
+ UnwrapMode::LegalFullUnwrapNoTapeReplace,
3045
+ /* scope*/ nullptr , /* permitCache*/ false );
3046
+ assert (nval);
3047
+ V->replaceAllUsesWith (nval);
3048
+ V->eraseFromParent ();
3025
3049
}
3026
3050
}
3027
3051
@@ -3044,13 +3068,17 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
3044
3068
3045
3069
if (auto bi = dyn_cast<BranchInst>(BB.getTerminator ())) {
3046
3070
3047
- Value *vals[1 ] = {gutils->getNewFromOriginal (bi->getCondition ())};
3048
- if (bi->getSuccessor (0 ) == unreachables[0 ]) {
3049
- gutils->replaceAWithB (vals[0 ],
3050
- ConstantInt::getFalse (vals[0 ]->getContext ()));
3051
- } else {
3052
- gutils->replaceAWithB (vals[0 ],
3053
- ConstantInt::getTrue (vals[0 ]->getContext ()));
3071
+ Value *condition = gutils->getNewFromOriginal (bi->getCondition ());
3072
+
3073
+ Constant *repVal = (bi->getSuccessor (0 ) == unreachables[0 ])
3074
+ ? ConstantInt::getFalse (condition->getContext ())
3075
+ : ConstantInt::getTrue (condition->getContext ());
3076
+
3077
+ for (auto UI = condition->use_begin (), E = condition->use_end ();
3078
+ UI != E;) {
3079
+ Use &U = *UI;
3080
+ ++UI;
3081
+ U.set (repVal);
3054
3082
}
3055
3083
}
3056
3084
}
@@ -3219,11 +3247,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
3219
3247
PPC.AlwaysInline (gutils->newFunc );
3220
3248
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
3221
3249
PPC.ReplaceReallocs (gutils->newFunc , /* mem2reg*/ true );
3222
- if (PostOpt)
3223
- PPC.optimizeIntermediate (gutils->newFunc );
3224
3250
3225
3251
auto nf = gutils->newFunc ;
3226
3252
delete gutils;
3253
+ if (PostOpt)
3254
+ PPC.optimizeIntermediate (nf);
3227
3255
if (EnzymePrint) {
3228
3256
llvm::errs () << *nf << " \n " ;
3229
3257
}
0 commit comments