Skip to content

Commit ab357c2

Browse files
authored
Improve sret style memory handling (rust-lang#860)
* Fix sret julia GC issue * Simplify * Fixup * Fix * Fix version inv
1 parent a6a92b4 commit ab357c2

10 files changed

+704
-398
lines changed

Diff for: enzyme/Enzyme/AdjointGenerator.h

+34-1
Original file line numberDiff line numberDiff line change
@@ -10377,7 +10377,12 @@ class AdjointGenerator
1037710377
if (!forwardsShadow) {
1037810378
if (Mode == DerivativeMode::ReverseModePrimal) {
1037910379
// Needs a stronger replacement check/assertion.
10380-
Value *replacement = UndefValue::get(placeholder->getType());
10380+
Value *replacement;
10381+
if (EnzymeZeroCache)
10382+
replacement = ConstantPointerNull::get(
10383+
cast<PointerType>(placeholder->getType()));
10384+
else
10385+
replacement = UndefValue::get(placeholder->getType());
1038110386
gutils->replaceAWithB(placeholder, replacement);
1038210387
gutils->invertedPointers.erase(found);
1038310388
gutils->invertedPointers.insert(std::make_pair(
@@ -11615,6 +11620,34 @@ class AdjointGenerator
1161511620
gradByVal[args.size()] = orig->getParamByValType(i);
1161611621
}
1161711622
#endif
11623+
11624+
bool writeOnlyNoCapture = true;
11625+
#if LLVM_VERSION_MAJOR >= 8
11626+
if (!orig->doesNotCapture(i))
11627+
#else
11628+
if (!(orig->dataOperandHasImpliedAttr(i + 1, Attribute::NoCapture) ||
11629+
(called && called->hasParamAttribute(i, Attribute::NoCapture))))
11630+
#endif
11631+
{
11632+
writeOnlyNoCapture = false;
11633+
}
11634+
#if LLVM_VERSION_MAJOR >= 14
11635+
if (!orig->onlyWritesMemory(i))
11636+
#else
11637+
if (!(orig->dataOperandHasImpliedAttr(i + 1, Attribute::WriteOnly) ||
11638+
orig->dataOperandHasImpliedAttr(i + 1, Attribute::ReadNone) ||
11639+
(called && (called->hasParamAttribute(i, Attribute::WriteOnly) ||
11640+
called->hasParamAttribute(i, Attribute::ReadNone)))))
11641+
#endif
11642+
{
11643+
writeOnlyNoCapture = false;
11644+
}
11645+
if (writeOnlyNoCapture) {
11646+
if (EnzymeZeroCache)
11647+
argi = ConstantPointerNull::get(cast<PointerType>(argi->getType()));
11648+
else
11649+
argi = UndefValue::get(argi->getType());
11650+
}
1161811651
args.push_back(lookup(argi, Builder2));
1161911652
}
1162011653

Diff for: enzyme/Enzyme/DifferentialUseAnalysis.h

+37-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929

3030
#include "GradientUtils.h"
3131

32-
typedef std::pair<const Value *, ValueType> UsageKey;
33-
3432
// Determine if a value is needed directly to compute the adjoint
3533
// of the given instruction user
3634
static inline bool is_use_directly_needed_in_reverse(
@@ -297,6 +295,43 @@ static inline bool is_use_directly_needed_in_reverse(
297295
// we still need even if instruction is inactive
298296
if (funcName == "llvm.julia.gc_preserve_begin")
299297
return true;
298+
299+
bool writeOnlyNoCapture = true;
300+
auto F = getFunctionFromCall(const_cast<CallInst *>(CI));
301+
#if LLVM_VERSION_MAJOR >= 14
302+
for (size_t i = 0; i < CI->arg_size(); i++)
303+
#else
304+
for (size_t i = 0; i < CI->getNumArgOperands(); i++)
305+
#endif
306+
{
307+
if (val == CI->getArgOperand(i)) {
308+
#if LLVM_VERSION_MAJOR >= 8
309+
if (!CI->doesNotCapture(i))
310+
#else
311+
if (!(CI->dataOperandHasImpliedAttr(i + 1, Attribute::NoCapture) ||
312+
(F && F->hasParamAttribute(i, Attribute::NoCapture))))
313+
#endif
314+
{
315+
writeOnlyNoCapture = false;
316+
break;
317+
}
318+
#if LLVM_VERSION_MAJOR >= 14
319+
if (!CI->onlyWritesMemory(i))
320+
#else
321+
if (!(CI->dataOperandHasImpliedAttr(i + 1, Attribute::WriteOnly) ||
322+
CI->dataOperandHasImpliedAttr(i + 1, Attribute::ReadNone) ||
323+
(F && (F->hasParamAttribute(i, Attribute::WriteOnly) ||
324+
F->hasParamAttribute(i, Attribute::ReadNone)))))
325+
#endif
326+
{
327+
writeOnlyNoCapture = false;
328+
break;
329+
}
330+
}
331+
}
332+
// Don't need the primal argument if it is write only and not captured
333+
if (writeOnlyNoCapture)
334+
return false;
300335
}
301336

302337
return !gutils->isConstantInstruction(user) ||

Diff for: enzyme/Enzyme/EnzymeLogic.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ struct CacheAnalysis {
185185
// Pointer operands originating from call instructions that are not
186186
// malloc/free are conservatively considered uncacheable.
187187
if (auto obj_op = dyn_cast<CallInst>(obj)) {
188+
auto n = getFuncNameFromCall(obj_op);
188189
// If this is a known allocation which is not captured or returned,
189190
// a caller function cannot overwrite this (since it cannot access).
190191
// Since we don't currently perform this check, we can instead check
@@ -193,6 +194,9 @@ struct CacheAnalysis {
193194
if (allocationsWithGuaranteedFree.find(obj_op) !=
194195
allocationsWithGuaranteedFree.end()) {
195196

197+
} else if (n == "julia.get_pgcstack" || n == "julia.ptls_states" ||
198+
n == "jl_get_ptls_states") {
199+
196200
} else {
197201
// OP is a non malloc/free call so we need to cache
198202
mustcache = true;
@@ -267,6 +271,13 @@ struct CacheAnalysis {
267271
oldFunc->getParent()->getDataLayout(), 100);
268272
#endif
269273

274+
if (auto obj_op = dyn_cast<CallInst>(obj)) {
275+
auto n = getFuncNameFromCall(obj_op);
276+
if (n == "julia.get_pgcstack" || n == "julia.ptls_states" ||
277+
n == "jl_get_ptls_states")
278+
return false;
279+
}
280+
270281
// Openmp bound and local thread id are unchanging
271282
// definitionally cacheable.
272283
if (omp)

Diff for: enzyme/Enzyme/FunctionUtils.cpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
285285
auto AS = cast<PointerType>(rep->getType())->getAddressSpace();
286286
if (AS == ASC->getDestAddressSpace()) {
287287
ASC->replaceAllUsesWith(rep);
288+
toErase.push_back(ASC);
288289
continue;
289290
}
290291
ASC->setOperand(0, rep);
@@ -360,13 +361,28 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
360361
continue;
361362
}
362363
}
364+
IRBuilder<> B(CI);
365+
auto Addr = B.CreateAddrSpaceCast(rep, prev->getType());
366+
#if LLVM_VERSION_MAJOR >= 14
367+
for (size_t i = 0; i < CI->arg_size(); i++)
368+
#else
369+
for (size_t i = 0; i < CI->getNumArgOperands(); i++)
370+
#endif
371+
{
372+
if (CI->getArgOperand(i) == prev) {
373+
CI->setArgOperand(i, Addr);
374+
}
375+
}
376+
continue;
363377
}
364378
llvm::errs() << " rep: " << *rep << " prev: " << *prev << " inst: " << *inst
365379
<< "\n";
366380
llvm_unreachable("Illegal address space propagation");
367381
}
368-
for (auto I : llvm::reverse(toErase))
382+
383+
for (auto I : llvm::reverse(toErase)) {
369384
I->eraseFromParent();
385+
}
370386
for (auto SI : toPostCache) {
371387
IRBuilder<> B(SI->getNextNode());
372388
PostCacheStore(SI, B);

0 commit comments

Comments
 (0)