Skip to content

Commit 8be573d

Browse files
wsmoses and vchuravy authored
Permit Zero Initialization of Cache (rust-lang#386)
* Zero Outermost Cache
* WIP: Silly billy
* Zero cache option

Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent 779ac0f commit 8be573d

File tree

5 files changed

+47
-4
lines changed

5 files changed

+47
-4
lines changed

enzyme/Enzyme/CacheUtility.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ llvm::cl::opt<bool>
3535
EfficientBoolCache("enzyme-smallbool", cl::init(false), cl::Hidden,
3636
cl::desc("Place 8 bools together in a single byte"));
3737

38+
llvm::cl::opt<bool> EnzymeZeroCache("enzyme-zero-cache", cl::init(false),
39+
cl::Hidden,
40+
cl::desc("Zero initialize the cache"));
41+
3842
llvm::cl::opt<bool>
3943
EnzymePrintPerf("enzyme-print-perf", cl::init(false), cl::Hidden,
4044
cl::desc("Enable Enzyme to print performance info"));
@@ -640,6 +644,9 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
640644
#endif
641645
}
642646
}
647+
if (EnzymeZeroCache && sublimits.size() == 0)
648+
scopeInstructions[alloc].push_back(
649+
entryBuilder.CreateStore(Constant::getNullValue(types.back()), alloc));
643650

644651
Type *BPTy = Type::getInt8PtrTy(T->getContext());
645652

@@ -699,6 +706,20 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
699706
cast<CallInst>(cast<Instruction>(firstallocation)->getOperand(0));
700707
}
701708

709+
if (EnzymeZeroCache && i == 0) {
710+
Value *args[] = {
711+
malloccall,
712+
ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0),
713+
malloccall->getArgOperand(0),
714+
ConstantInt::getFalse(malloccall->getContext())};
715+
Type *tys[] = {args[0]->getType(), args[2]->getType()};
716+
717+
scopeInstructions[alloc].push_back(allocationBuilder.CreateCall(
718+
Intrinsic::getDeclaration(newFunc->getParent(), Intrinsic::memset,
719+
tys),
720+
args));
721+
}
722+
702723
// Assert computation of size of array doesn't wrap
703724
if (auto BI = dyn_cast<BinaryOperator>(malloccall->getArgOperand(0))) {
704725
if ((BI->getOperand(0) == byteSizeOfType &&
@@ -788,8 +809,9 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
788809
Value *realloccall = nullptr;
789810

790811
realloccall = build.CreateCall(
791-
getOrInsertExponentialAllocator(*newFunc->getParent()), idxs,
792-
name + "_realloccache");
812+
getOrInsertExponentialAllocator(*newFunc->getParent(),
813+
EnzymeZeroCache && i == 0),
814+
idxs, name + "_realloccache");
793815
scopeAllocs[alloc].push_back(cast<CallInst>(realloccall));
794816
allocation = build.CreateBitCast(realloccall, allocation->getType());
795817
storealloc = build.CreateStore(allocation, storeInto);

enzyme/Enzyme/CacheUtility.h

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
extern "C" {
4444
/// Pack 8 bools together in a single byte
4545
extern llvm::cl::opt<bool> EfficientBoolCache;
46+
47+
extern llvm::cl::opt<bool> EnzymeZeroCache;
4648
}
4749

4850
/// Container for all loop information to synthesize gradients

enzyme/Enzyme/EnzymeLogic.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -2144,6 +2144,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
21442144
IRBuilder<> ib(NewF->getEntryBlock().getFirstNonPHI());
21452145

21462146
Value *ret = noReturn ? nullptr : ib.CreateAlloca(RetType);
2147+
if (!noReturn && EnzymeZeroCache) {
2148+
ib.CreateStore(Constant::getNullValue(RetType), ret);
2149+
}
21472150

21482151
if (!noTape) {
21492152
Value *tapeMemory;
@@ -2169,6 +2172,19 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
21692172
malloccall->addAttribute(AttributeList::ReturnIndex, attr);
21702173
#endif
21712174
}
2175+
if (EnzymeZeroCache) {
2176+
IRBuilder<> B(malloccall->getNextNode());
2177+
Value *args[] = {
2178+
malloccall,
2179+
ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0),
2180+
malloccall->getArgOperand(0),
2181+
ConstantInt::getFalse(malloccall->getContext())};
2182+
Type *tys[] = {args[0]->getType(), args[2]->getType()};
2183+
2184+
B.CreateCall(Intrinsic::getDeclaration(NewF->getParent(),
2185+
Intrinsic::memset, tys),
2186+
args);
2187+
}
21722188
#if LLVM_VERSION_MAJOR >= 14
21732189
malloccall->addDereferenceableRetAttr(size->getLimitedValue());
21742190
AttrBuilder B;

enzyme/Enzyme/Utils.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -531,11 +531,14 @@ llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr,
531531
return B2.CreateLoad(GV);
532532
}
533533

534-
Function *getOrInsertExponentialAllocator(Module &M) {
534+
Function *getOrInsertExponentialAllocator(Module &M, bool ZeroInit) {
535535
Type *BPTy = Type::getInt8PtrTy(M.getContext());
536536
Type *types[] = {BPTy, Type::getInt64Ty(M.getContext()),
537537
Type::getInt64Ty(M.getContext())};
538538
std::string name = "__enzyme_exponentialallocation";
539+
if (ZeroInit)
540+
name += "zero";
541+
assert(!ZeroInit && "Zero initialization within reallocation not handled");
539542
FunctionType *FT =
540543
FunctionType::get(Type::getInt8PtrTy(M.getContext()), types, false);
541544

enzyme/Enzyme/Utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ enum class MPI_CallType {
833833
llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr,
834834
ConcreteType CT, llvm::Type *intType,
835835
llvm::IRBuilder<> &B2);
836-
llvm::Function *getOrInsertExponentialAllocator(llvm::Module &M);
836+
llvm::Function *getOrInsertExponentialAllocator(llvm::Module &M, bool ZeroInit);
837837

838838
class AssertingReplacingVH : public llvm::CallbackVH {
839839
public:

0 commit comments

Comments
 (0)