Skip to content

Commit

Permalink
Correct Forward mode custom call convention (rust-lang#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 4, 2022
1 parent 54c93e6 commit b373450
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 29 deletions.
27 changes: 17 additions & 10 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -6272,15 +6272,23 @@ class AdjointGenerator
DIFFE_TYPE subretType;
if (gutils->isConstantValue(orig)) {
subretType = DIFFE_TYPE::CONSTANT;
} else if (!orig->getType()->isFPOrFPVectorTy() &&
TR.query(orig).Inner0().isPossiblePointer()) {
if (is_value_needed_in_reverse<ValueType::ShadowPtr>(
TR, gutils, orig, Mode, oldUnreachable))
subretType = DIFFE_TYPE::DUP_ARG;
else
subretType = DIFFE_TYPE::CONSTANT;
} else {
subretType = DIFFE_TYPE::OUT_DIFF;
if (Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeSplit ||
Mode == DerivativeMode::ForwardModeVector) {
subretType = DIFFE_TYPE::DUP_ARG;
} else {
if (!orig->getType()->isFPOrFPVectorTy() &&
TR.query(orig).Inner0().isPossiblePointer()) {
if (is_value_needed_in_reverse<ValueType::ShadowPtr>(
TR, gutils, orig, Mode, oldUnreachable))
subretType = DIFFE_TYPE::DUP_ARG;
else
subretType = DIFFE_TYPE::CONSTANT;
} else {
subretType = DIFFE_TYPE::OUT_DIFF;
}
}
}

if (Mode == DerivativeMode::ForwardMode) {
Expand Down Expand Up @@ -8133,8 +8141,7 @@ class AdjointGenerator
newcalled = gutils->Logic.CreateForwardDiff(
cast<Function>(called), subretType, argsInverted, gutils->TLI,
TR.analyzer.interprocedural, /*returnValue*/ subretused,
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
nextTypeInfo, {});
DerivativeMode::ForwardMode, nullptr, nextTypeInfo, {});
} else {
#if LLVM_VERSION_MAJOR >= 11
auto callval = orig->getCalledOperand();
Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ void EnzymeGradientUtilsSubTransferHelper(
LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t PostOpt) {
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, uint8_t PostOpt) {
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
constant_args_size);
Expand All @@ -391,7 +391,7 @@ LLVMValueRef EnzymeCreateForwardDiff(
}
return wrap(eunwrap(Logic).CreateForwardDiff(
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode,
eunwrap(TA).TLI, eunwrap(TA), returnValue, (DerivativeMode)mode,
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
uncacheable_args, PostOpt));
}
Expand Down
14 changes: 6 additions & 8 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,12 @@ typedef enum {
DEM_ReverseModeCombined = 3,
} CDerivativeMode;

LLVMValueRef
EnzymeCreateForwardDiff(EnzymeLogicRef, LLVMValueRef todiff,
CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, LLVMTypeRef additionalArg,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, uint8_t PostOpt);
LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
LLVMTypeRef additionalArg, struct CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t PostOpt);

LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ class Enzyme : public ModulePass {
case DerivativeMode::ForwardMode:
newFunc = Logic.CreateForwardDiff(
cast<Function>(fn), retType, constants, TLI, TA,
/*should return*/ false, /*dretPtr*/ false, mode,
/*should return*/ false, mode,
/*addedType*/ nullptr, type_args, volatile_args, PostOpt);
break;
case DerivativeMode::ReverseModeCombined:
Expand Down
54 changes: 52 additions & 2 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3724,10 +3724,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
Function *EnzymeLogic::CreateForwardDiff(
Function *todiff, DIFFE_TYPE retType,
const std::vector<DIFFE_TYPE> &constant_args, TargetLibraryInfo &TLI,
TypeAnalysis &TA, bool returnUsed, bool dretPtr, DerivativeMode mode,
TypeAnalysis &TA, bool returnUsed, DerivativeMode mode,
llvm::Type *additionalArg, const FnTypeInfo &oldTypeInfo_,
const std::map<Argument *, bool> _uncacheable_args, bool PostOpt,
bool omp) {
assert(retType != DIFFE_TYPE::OUT_DIFF);

assert(mode == DerivativeMode::ForwardMode ||
mode == DerivativeMode::ForwardModeVector ||
Expand All @@ -3742,14 +3743,15 @@ Function *EnzymeLogic::CreateForwardDiff(
std::make_tuple(todiff, retType, constant_args,
std::map<Argument *, bool>(_uncacheable_args.begin(),
_uncacheable_args.end()),
returnUsed, dretPtr, mode, additionalArg, oldTypeInfo);
returnUsed, mode, additionalArg, oldTypeInfo);
if (ForwardCachedFunctions.find(tup) != ForwardCachedFunctions.end()) {
return ForwardCachedFunctions.find(tup)->second;
}

// TODO change this to go by default function type assumptions
bool hasconstant = false;
for (auto v : constant_args) {
assert(v != DIFFE_TYPE::OUT_DIFF);
if (v == DIFFE_TYPE::CONSTANT) {
hasconstant = true;
break;
Expand All @@ -3769,6 +3771,54 @@ Function *EnzymeLogic::CreateForwardDiff(
auto gvemd = cast<ConstantAsMetadata>(md2->getOperand(0));
auto foundcalled = cast<Function>(gvemd->getValue());

if (!foundcalled->getReturnType()->isVoidTy()) {
if (returnUsed && retType == DIFFE_TYPE::CONSTANT) {
FunctionType *FTy = FunctionType::get(
todiff->getReturnType(), foundcalled->getFunctionType()->params(),
foundcalled->getFunctionType()->isVarArg());
Function *NewF = Function::Create(
FTy, Function::LinkageTypes::InternalLinkage,
"fixderivative_" + todiff->getName(), todiff->getParent());
for (auto pair : llvm::zip(NewF->args(), foundcalled->args())) {
std::get<0>(pair).setName(std::get<1>(pair).getName());
}

BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF);
IRBuilder<> bb(BB);
SmallVector<Value *, 2> args;
for (auto &a : NewF->args())
args.push_back(&a);
auto cal = bb.CreateCall(foundcalled, args);
cal->setCallingConv(foundcalled->getCallingConv());

bb.CreateRet(bb.CreateExtractValue(cal, 0));
return ForwardCachedFunctions[tup] = NewF;
}
if (!returnUsed && retType != DIFFE_TYPE::CONSTANT) {
FunctionType *FTy = FunctionType::get(
todiff->getReturnType(), foundcalled->getFunctionType()->params(),
foundcalled->getFunctionType()->isVarArg());
Function *NewF = Function::Create(
FTy, Function::LinkageTypes::InternalLinkage,
"fixderivative_" + todiff->getName(), todiff->getParent());
for (auto pair : llvm::zip(NewF->args(), foundcalled->args())) {
std::get<0>(pair).setName(std::get<1>(pair).getName());
}

BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF);
IRBuilder<> bb(BB);
SmallVector<Value *, 2> args;
for (auto &a : NewF->args())
args.push_back(&a);
auto cal = bb.CreateCall(foundcalled, args);
cal->setCallingConv(foundcalled->getCallingConv());

bb.CreateRet(bb.CreateExtractValue(cal, 1));
return ForwardCachedFunctions[tup] = NewF;
}
assert(returnUsed);
}

return foundcalled;
}
if (todiff->empty())
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ class EnzymeLogic {
std::tuple<llvm::Function *, DIFFE_TYPE /*retType*/,
std::vector<DIFFE_TYPE> /*constant_args*/,
std::map<llvm::Argument *, bool> /*uncacheable_args*/,
bool /*retval*/, bool /*dretPtr*/, DerivativeMode,
llvm::Type *, const FnTypeInfo>;
bool /*retval*/, DerivativeMode, llvm::Type *,
const FnTypeInfo>;
std::map<ForwardCacheKey, llvm::Function *> ForwardCachedFunctions;

/// Create the derivative function itself.
Expand Down Expand Up @@ -287,7 +287,7 @@ class EnzymeLogic {
CreateForwardDiff(llvm::Function *todiff, DIFFE_TYPE retType,
const std::vector<DIFFE_TYPE> &constant_args,
llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA,
bool returnValue, bool dretUsed, DerivativeMode mode,
bool returnValue, DerivativeMode mode,
llvm::Type *additionalArg, const FnTypeInfo &typeInfo,
const std::map<llvm::Argument *, bool> _uncacheable_args,
bool PostOpt = false, bool omp = false);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2718,7 +2718,7 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
switch (mode) {
case DerivativeMode::ForwardMode: {
Constant *newf =
Logic.CreateForwardDiff(fn, retType, types, TLI, TA, false, false, mode,
Logic.CreateForwardDiff(fn, retType, types, TLI, TA, false, mode,
nullptr, type_args, uncacheable_args);

if (!newf)
Expand Down
42 changes: 42 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/custom0.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s

source_filename = "exer2.c"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@__enzyme_register_derivative_add = dso_local local_unnamed_addr global [2 x i8*] [i8* bitcast (double (double, double)* @add to i8*), i8* bitcast ({ double, double } (double, double, double, double)* @add_err to i8*)], align 16

declare double @add(double %x, double %y) #0

declare { double, double } @add_err(double %v1, double %v1err, double %v2, double %v2err)

; Function Attrs: norecurse nounwind readnone uwtable willreturn
define double @f(double %x) {
entry:
%call = call double @add(double %x, double %x)
ret double %call
}

; Function Attrs: nounwind uwtable
define double @caller(double %x, double %dx) {
entry:
%call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @f to i8*), double %x, double %dx)
ret double %call
}

declare dso_local double @__enzyme_fwddiff(i8*, ...)

attributes #0 = { norecurse nounwind readnone }

; CHECK: define internal double @fwddiffef(double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @fixderivative_add(double %x, double %"x'", double %x, double %"x'")
; CHECK-NEXT: ret double %0
; CHECK-NEXT: }

; CHECK: define internal double @fixderivative_add(double %v1, double %v1err, double %v2, double %v2err) {
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call { double, double } @add_err(double %v1, double %v1err, double %v2, double %v2err)
; CHECK-NEXT: %1 = extractvalue { double, double } %0, 1
; CHECK-NEXT: ret double %1
; CHECK-NEXT: }
40 changes: 40 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/custom1.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s

source_filename = "exer2.c"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@__enzyme_register_derivative_add = dso_local local_unnamed_addr global [2 x i8*] [i8* bitcast (double (double, double)* @add to i8*), i8* bitcast ({ double, double } (double, double, double, double)* @add_err to i8*)], align 16

declare double @add(double %x, double %y) #0

declare { double, double } @add_err(double %v1, double %v1err, double %v2, double %v2err)

; Function Attrs: norecurse nounwind readnone uwtable willreturn
define double @f(double %x) {
entry:
%call = call double @add(double %x, double %x)
%mul = fmul double %call, %call
ret double %mul
}

; Function Attrs: nounwind uwtable
define double @caller(double %x, double %dx) {
entry:
%call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @f to i8*), double %x, double %dx)
ret double %call
}

declare dso_local double @__enzyme_fwddiff(i8*, ...)

attributes #0 = { norecurse nounwind readnone }

; CHECK: define internal double @fwddiffef(double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call { double, double } @add_err(double %x, double %"x'", double %x, double %"x'")
; CHECK-NEXT: %1 = extractvalue { double, double } %0, 0
; CHECK-NEXT: %2 = extractvalue { double, double } %0, 1
; CHECK-NEXT: %3 = fmul fast double %2, %1
; CHECK-NEXT: %4 = fadd fast double %3, %3
; CHECK-NEXT: ret double %4
; CHECK-NEXT: }

0 comments on commit b373450

Please sign in to comment.