Skip to content

Commit

Permalink
Expose Subtransfer and additional C API functions (rust-lang#384)
Browse files Browse the repository at this point in the history
* [CAPI] Expose subtransfer

* Update

* Fixup

* Update enzyme/Enzyme/CApi.cpp

Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>

* Update enzyme/Enzyme/CApi.cpp

Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>

* Update enzyme/Enzyme/CApi.cpp

Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>

* Update enzyme/Enzyme/CApi.cpp

Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>

* Update enzyme/Enzyme/CApi.cpp

Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>

Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
  • Loading branch information
wsmoses and vchuravy authored Nov 30, 2021
1 parent cbe56be commit deba550
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 179 deletions.
199 changes: 20 additions & 179 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2166,178 +2166,6 @@ class AdjointGenerator
}
}

void subTransferHelper(Type *secretty, BasicBlock *parent,
Intrinsic::ID intrinsic, unsigned dstalign,
unsigned srcalign, unsigned offset, Value *orig_dst,
Value *orig_src, Value *length, Value *isVolatile,
llvm::CallInst *MTI, bool allowForward = true) {
// TODO offset
if (secretty) {
// no change to forward pass if represents floats
if (Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeCombined) {
IRBuilder<> Builder2(parent);
getReverseBuilder(Builder2);

// If the src is constant simply zero d_dst and don't propagate to d_src
// (which thus == src and may be illegal)
if (gutils->isConstantValue(orig_src)) {
SmallVector<Value *, 4> args;
args.push_back(
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2));
if (args[0]->getType()->isIntegerTy())
args[0] = Builder2.CreateIntToPtr(
args[0], Type::getInt8PtrTy(MTI->getContext()));
args.push_back(
ConstantInt::get(Type::getInt8Ty(parent->getContext()), 0));
args.push_back(lookup(length, Builder2));
#if LLVM_VERSION_MAJOR <= 6
args.push_back(ConstantInt::get(
Type::getInt32Ty(parent->getContext()), max(1U, dstalign)));
#endif
args.push_back(ConstantInt::getFalse(parent->getContext()));

Type *tys[] = {args[0]->getType(), args[2]->getType()};
auto memsetIntr = Intrinsic::getDeclaration(
parent->getParent()->getParent(), Intrinsic::memset, tys);
auto cal = Builder2.CreateCall(memsetIntr, args);
cal->setCallingConv(memsetIntr->getCallingConv());
if (dstalign != 0) {
#if LLVM_VERSION_MAJOR >= 10
cal->addParamAttr(0, Attribute::getWithAlignment(
parent->getContext(), Align(dstalign)));
#else
cal->addParamAttr(
0, Attribute::getWithAlignment(parent->getContext(), dstalign));
#endif
}

} else {
SmallVector<Value *, 4> args;
auto dsto =
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2);
if (dsto->getType()->isIntegerTy())
dsto = Builder2.CreateIntToPtr(
dsto, Type::getInt8PtrTy(dsto->getContext()));
unsigned dstaddr =
cast<PointerType>(dsto->getType())->getAddressSpace();
auto secretpt = PointerType::get(secretty, dstaddr);
if (offset != 0)
dsto = Builder2.CreateConstInBoundsGEP1_64(dsto, offset);
args.push_back(Builder2.CreatePointerCast(dsto, secretpt));
auto srco =
lookup(gutils->invertPointerM(orig_src, Builder2), Builder2);
if (srco->getType()->isIntegerTy())
srco = Builder2.CreateIntToPtr(
srco, Type::getInt8PtrTy(srco->getContext()));
unsigned srcaddr =
cast<PointerType>(srco->getType())->getAddressSpace();
secretpt = PointerType::get(secretty, srcaddr);
if (offset != 0)
srco = Builder2.CreateConstInBoundsGEP1_64(srco, offset);
args.push_back(Builder2.CreatePointerCast(srco, secretpt));
args.push_back(Builder2.CreateUDiv(
lookup(length, Builder2),

ConstantInt::get(length->getType(),
Builder2.GetInsertBlock()
->getParent()
->getParent()
->getDataLayout()
.getTypeAllocSizeInBits(secretty) /
8)));

auto dmemcpy = ((intrinsic == Intrinsic::memcpy)
? getOrInsertDifferentialFloatMemcpy
: getOrInsertDifferentialFloatMemmove)(
*parent->getParent()->getParent(), secretty, dstalign, srcalign,
dstaddr, srcaddr);
Builder2.CreateCall(dmemcpy, args);
}
}
} else {

// if represents pointer or integer type then only need to modify forward
// pass with the copy
if (allowForward && (Mode == DerivativeMode::ReverseModePrimal ||
Mode == DerivativeMode::ReverseModeCombined)) {

// It is questionable how the following case would even occur, but if
// the dst is constant, we shouldn't do anything extra
if (gutils->isConstantValue(orig_dst)) {
return;
}

SmallVector<Value *, 4> args;
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(MTI));

// If src is inactive, then we should copy from the regular pointer
// (i.e. suppose we are copying constant memory representing dimensions
// into a tensor)
// to ensure that the differential tensor is well formed for use
// OUTSIDE the derivative generation (as enzyme doesn't need this), we
// should also perform the copy onto the differential. Future
// Optimization (not implemented): If dst can never escape Enzyme code,
// we may omit this copy.
// no need to update pointers, even if dst is active
auto dsto = gutils->invertPointerM(orig_dst, BuilderZ);
if (dsto->getType()->isIntegerTy())
dsto = BuilderZ.CreateIntToPtr(dsto,
Type::getInt8PtrTy(MTI->getContext()));
if (offset != 0)
dsto = BuilderZ.CreateConstInBoundsGEP1_64(dsto, offset);
args.push_back(dsto);
auto srco = gutils->invertPointerM(orig_src, BuilderZ);
if (srco->getType()->isIntegerTy())
srco = BuilderZ.CreateIntToPtr(srco,
Type::getInt8PtrTy(MTI->getContext()));
if (offset != 0)
srco = BuilderZ.CreateConstInBoundsGEP1_64(srco, offset);
args.push_back(srco);

args.push_back(length);
#if LLVM_VERSION_MAJOR <= 6
args.push_back(ConstantInt::get(Type::getInt32Ty(parent->getContext()),
max(1U, min(srcalign, dstalign))));
#endif
args.push_back(isVolatile);

//#if LLVM_VERSION_MAJOR >= 7
Type *tys[] = {args[0]->getType(), args[1]->getType(),
args[2]->getType()};
//#else
// Type *tys[] = {args[0]->getType(), args[1]->getType(),
// args[2]->getType(), args[3]->getType()}; #endif

auto memtransIntr = Intrinsic::getDeclaration(
gutils->newFunc->getParent(), intrinsic, tys);
auto cal = BuilderZ.CreateCall(memtransIntr, args);
cal->setAttributes(MTI->getAttributes());
cal->setCallingConv(memtransIntr->getCallingConv());
cal->setTailCallKind(MTI->getTailCallKind());

if (dstalign != 0) {
#if LLVM_VERSION_MAJOR >= 10
cal->addParamAttr(0, Attribute::getWithAlignment(parent->getContext(),
Align(dstalign)));
#else
cal->addParamAttr(
0, Attribute::getWithAlignment(parent->getContext(), dstalign));
#endif
}
if (srcalign != 0) {
#if LLVM_VERSION_MAJOR >= 10
cal->addParamAttr(1, Attribute::getWithAlignment(parent->getContext(),
Align(srcalign)));
#else
cal->addParamAttr(
1, Attribute::getWithAlignment(parent->getContext(), srcalign));
#endif
}
}
}
}

void visitMemTransferInst(llvm::MemTransferInst &MTI) {
#if LLVM_VERSION_MAJOR >= 7
Value *isVolatile = gutils->getNewFromOriginal(MTI.getOperand(3));
Expand All @@ -2352,16 +2180,20 @@ class AdjointGenerator
auto dstAlign = MTI.getDestAlignment();
#endif
visitMemTransferCommon(MTI.getIntrinsicID(), srcAlign, dstAlign, MTI,
MTI.getOperand(0), MTI.getOperand(1),
gutils->getNewFromOriginal(MTI.getOperand(2)),
isVolatile);
}

#if LLVM_VERSION_MAJOR >= 10
void visitMemTransferCommon(Intrinsic::ID ID, MaybeAlign srcAlign,
MaybeAlign dstAlign, llvm::CallInst &MTI,
Value *orig_dst, Value *orig_src, Value *new_size,
Value *isVolatile)
#else
void visitMemTransferCommon(Intrinsic::ID ID, unsigned srcAlign,
unsigned dstAlign, llvm::CallInst &MTI,
Value *orig_dst, Value *orig_src, Value *new_size,
Value *isVolatile)
#endif
{
Expand All @@ -2375,10 +2207,6 @@ class AdjointGenerator
return;
}

Value *orig_dst = MTI.getOperand(0);
Value *orig_src = MTI.getOperand(1);
Value *new_size = gutils->getNewFromOriginal(MTI.getOperand(2));

// copying into nullptr is invalid (not sure why it exists here), but we
// shouldn't do it in reverse pass or shadow
if (isa<ConstantPointerNull>(orig_dst) ||
Expand Down Expand Up @@ -2529,8 +2357,17 @@ class AdjointGenerator
srcalign = 1;
}
}
subTransferHelper(dt.isFloat(), MTI.getParent(), ID, subdstalign,
subsrcalign, /*offset*/ start, orig_dst, orig_src,
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI));
Value *shadow_dst = gutils->isConstantValue(orig_dst)
? gutils->getNewFromOriginal(orig_dst)
: gutils->invertPointerM(orig_dst, BuilderZ);
Value *shadow_src = gutils->isConstantValue(orig_src)
? gutils->getNewFromOriginal(orig_src)
: gutils->invertPointerM(orig_src, BuilderZ);
SubTransferHelper(gutils, Mode, dt.isFloat(), ID, subdstalign,
subsrcalign, /*offset*/ start,
gutils->isConstantValue(orig_dst), shadow_dst,
gutils->isConstantValue(orig_src), shadow_src,
/*length*/ length, /*volatile*/ isVolatile, &MTI);

if (nextStart == size)
Expand Down Expand Up @@ -7741,10 +7578,14 @@ class AdjointGenerator
#if LLVM_VERSION_MAJOR >= 10
visitMemTransferCommon(ID, /*srcAlign*/ MaybeAlign(1),
/*dstAlign*/ MaybeAlign(1), *orig,
orig->getArgOperand(0), orig->getArgOperand(1),
gutils->getNewFromOriginal(orig->getArgOperand(2)),
ConstantInt::getFalse(orig->getContext()));
#else
visitMemTransferCommon(ID, /*srcAlign*/ 1,
/*dstAlign*/ 1, *orig,
/*dstAlign*/ 1, *orig, orig->getArgOperand(0),
orig->getArgOperand(1),
gutils->getNewFromOriginal(orig->getArgOperand(2)),
ConstantInt::getFalse(orig->getContext()));
#endif
return;
Expand Down
38 changes: 38 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils,
return wrap(gutils->getNewFromOriginal(unwrap(val)));
}

CDerivativeMode EnzymeGradientUtilsGetMode(GradientUtils *gutils) {
return (CDerivativeMode)gutils->mode;
}

void EnzymeGradientUtilsSetDebugLocFromOriginal(GradientUtils *gutils,
LLVMValueRef val,
LLVMValueRef orig) {
Expand Down Expand Up @@ -331,6 +335,31 @@ LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) {
return wrap(gutils->inversionAllocs);
}

CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils,
LLVMValueRef val) {
auto v = unwrap(val);
assert(gutils->my_TR);
TypeTree TT = gutils->my_TR->query(v);
TypeTree *pTT = new TypeTree(TT);
return (CTypeTreeRef)pTT;
}

void EnzymeGradientUtilsSubTransferHelper(
GradientUtils *gutils, CDerivativeMode mode, LLVMTypeRef secretty,
uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset,
uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant,
LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile,
LLVMValueRef MTI, uint8_t allowForward) {
auto orig = unwrap(MTI);
assert(orig);
SubTransferHelper(gutils, (DerivativeMode)mode, unwrap(secretty),
(Intrinsic::ID)intrinsic, (unsigned)dstAlign,
(unsigned)srcAlign, (unsigned)offset, (bool)dstConstant,
unwrap(shadow_dst), (bool)srcConstant, unwrap(shadow_src),
unwrap(length), unwrap(isVolatile), cast<CallInst>(orig),
(bool)allowForward);
}

LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
Expand Down Expand Up @@ -467,6 +496,15 @@ void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x) {
void EnzymeTypeTreeData0Eq(CTypeTreeRef CTT) {
*(TypeTree *)CTT = ((TypeTree *)CTT)->Data0();
}

void EnzymeTypeTreeLookupEq(CTypeTreeRef CTT, int64_t size, const char *dl) {
*(TypeTree *)CTT = ((TypeTree *)CTT)->Lookup(size, DataLayout(dl));
}

CConcreteType EnzymeTypeTreeInner0(CTypeTreeRef CTT) {
return ewrap(((TypeTree *)CTT)->Inner0());
}

void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout,
int64_t offset, int64_t maxSize,
uint64_t addOffset) {
Expand Down
Loading

0 comments on commit deba550

Please sign in to comment.