Skip to content

Commit

Permalink
Handle dupnoneed in forward mode (rust-lang#396)
Browse files Browse the repository at this point in the history
* Handle dupnoneed in forward mode

* Fix allocs with more than one parameter (rust-lang#397)

* Add test
  • Loading branch information
tgymnich authored Dec 16, 2021
1 parent 394992e commit 3872f89
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
6 changes: 5 additions & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7455,7 +7455,11 @@ class AdjointGenerator
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);

SmallVector<Value *, 2> args = {orig->getArgOperand(0)};
SmallVector<Value *, 2> args;
for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) {
auto arg = orig->getArgOperand(i);
args.push_back(gutils->getNewFromOriginal(arg));
}
CallInst *CI = Builder2.CreateCall(orig->getFunctionType(),
orig->getCalledFunction(), args);
CI->setAttributes(orig->getAttributes());
Expand Down
10 changes: 9 additions & 1 deletion enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3848,7 +3848,9 @@ Function *EnzymeLogic::CreateForwardDiff(

for (size_t i = 0; i < constant_args.size(); ++i) {
auto arg = constant_args[i];
if (arg == DIFFE_TYPE::DUP_ARG) {
switch (arg) {
case DIFFE_TYPE::DUP_ARG:
case DIFFE_TYPE::DUP_NONEED: {
newArgs += 1;
auto pri = gutils->oldFunc->arg_begin() + i;
auto dif = newArgs;
Expand All @@ -3857,6 +3859,12 @@ Function *EnzymeLogic::CreateForwardDiff(
IRBuilder<> Builder(&BB.front());

gutils->setDiffe(pri, dif, Builder);
break;
}
case DIFFE_TYPE::CONSTANT:
break;
case DIFFE_TYPE::OUT_DIFF:
report_fatal_error("unsupported DIFFE_TYPE");
}
newArgs += 1;
}
Expand Down
40 changes: 40 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/calloc.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -simplifycfg -S | FileCheck %s


@enzyme_dupnoneed = dso_local global i32 0, align 4

define dso_local double @f(double %x, i64 %arg) {
entry:
%call = call noalias i8* @calloc(i64 8, i64 %arg)
%0 = bitcast i8* %call to double*
store double %x, double* %0, align 8
%1 = load double, double* %0, align 8
ret double %1
}

declare dso_local noalias i8* @calloc(i64, i64)

define dso_local double @df(double %x) {
entry:
%x.addr = alloca double, align 8
store double %x, double* %x.addr, align 8
%0 = load i32, i32* @enzyme_dupnoneed, align 4
%1 = load double, double* %x.addr, align 8
%call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double,i64)* @f to i8*), i32 %0, double %1, double 1.000000e+00, i64 1)
ret double %call
}

declare dso_local double @__enzyme_fwddiff(i8*, ...)


; CHECK: define internal double @fwddiffef(double %x, double %"x'", i64 %arg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call noalias i8* @calloc(i64 8, i64 %arg)
; CHECK-NEXT: %0 = call noalias i8* @calloc(i64 8, i64 %arg)
; CHECK-NEXT: %"'ipc" = bitcast i8* %0 to double*
; CHECK-NEXT: %1 = bitcast i8* %call to double*
; CHECK-NEXT: store double %x, double* %1, align 8
; CHECK-NEXT: store double %"x'", double* %"'ipc", align 8
; CHECK-NEXT: %2 = load double, double* %"'ipc", align 8
; CHECK-NEXT: ret double %2
; CHECK-NEXT: }

0 comments on commit 3872f89

Please sign in to comment.