Skip to content

Commit

Permalink
Handle Pointer Returns in Forward Mode (rust-lang#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Oct 13, 2021
1 parent 534beda commit ffe46d5
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 11 deletions.
17 changes: 13 additions & 4 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7757,10 +7757,19 @@ class AdjointGenerator
} else {
diffe = diffes;
}
gutils->replaceAWithB(newcall, diffe);
gutils->erase(newcall);
if (!gutils->isConstantValue(&call))
setDiffe(&call, diffe, Builder2);

auto ifound = gutils->invertedPointers.find(orig);
if (ifound != gutils->invertedPointers.end()) {
auto placeholder = cast<PHINode>(&*ifound->second);
gutils->replaceAWithB(placeholder, diffe);
gutils->erase(placeholder);
} else {
gutils->replaceAWithB(newcall, diffe);
gutils->erase(newcall);
if (!gutils->isConstantValue(&call)) {
setDiffe(&call, diffe, Builder2);
}
}
} else {
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
}
Expand Down
19 changes: 12 additions & 7 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2314,8 +2314,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return AugmentedCachedFunctions.find(tup)->second;
}

void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
DIFFE_TYPE retType, ReturnType retVal) {
void createTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
BasicBlock *oBB, DIFFE_TYPE retType, ReturnType retVal) {

BasicBlock *nBB = cast<BasicBlock>(gutils->getNewFromOriginal(oBB));
assert(nBB);
Expand All @@ -2341,8 +2341,14 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,

toret =
nBuilder.CreateInsertValue(toret, gutils->getNewFromOriginal(ret), 0);
toret =
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);

if (TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
toret = nBuilder.CreateInsertValue(
toret, gutils->invertPointerM(ret, nBuilder), 1);
} else {
toret =
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);
}
break;
}
case ReturnType::Void: {
Expand Down Expand Up @@ -3662,8 +3668,7 @@ Function *EnzymeLogic::CreateForwardDiff(
}

auto TRo = TA.analyzeFunction(oldTypeInfo);
bool retActive = TRo.getReturnAnalysis().Inner0().isPossibleFloat() &&
!todiff->getReturnType()->isVoidTy();
bool retActive = retType != DIFFE_TYPE::CONSTANT;

ReturnType retVal =
returnValue ? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
Expand Down Expand Up @@ -3847,7 +3852,7 @@ Function *EnzymeLogic::CreateForwardDiff(
maker->visit(&*it);
}

createTerminator(gutils, &oBB, retType, retVal);
createTerminator(TR, gutils, &oBB, retType, retVal);
}

gutils->eraseFictiousPHIs();
Expand Down
61 changes: 61 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/ptr-ret.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s

; Test input: stores %x into newly allocated memory and returns the address.
; The return value is a double*, so this function exercises the
; pointer-return path of forward-mode differentiation; the expected derivative
; at the end of this file returns a { double*, double* } pair.
define dso_local noalias nonnull double* @_Z6toHeapd(double %x) {
entry:
; Request 8 bytes from @_Znwm (declared just below this function).
%call = call noalias nonnull dereferenceable(8) i8* @_Znwm(i64 8)
; Reinterpret the raw i8* as a double* so the argument can be stored through it.
%0 = bitcast i8* %call to double*
store double %x, double* %0, align 8
ret double* %0
}

declare dso_local nonnull i8* @_Znwm(i64)

; Function being differentiated: obtains a pointer from @_Z6toHeapd(%x), loads
; the double it points to, and returns that value multiplied by %x.
define dso_local double @_Z6squared(double %x) {
entry:
; Callee returns a pointer; the value loaded from it is what feeds the multiply.
%call = call double* @_Z6toHeapd(double %x)
%0 = load double, double* %call, align 8
%mul = fmul double %0, %x
ret double %mul
}

; Test driver: passes @_Z6squared (cast to i8*), the input %x and the constant
; 1.0 to the variadic @_Z16__enzyme_fwddiffz declaration and returns its result.
; The expected output below shows this call replaced by a direct call to
; @fwddiffe_Z6squared(%x, 1.0).
define dso_local double @_Z7dsquared(double %x) {
entry:
%call = call double (...) @_Z16__enzyme_fwddiffz(i8* bitcast (double (double)* @_Z6squared to i8*), double %x, double 1.000000e+00)
ret double %call
}

declare dso_local double @_Z16__enzyme_fwddiffz(...)



; CHECK: define dso_local double @_Z7dsquared(double %x)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @fwddiffe_Z6squared(double %x, double 1.000000e+00)
; CHECK-NEXT: ret double %0
; CHECK-NEXT: }

; CHECK: define internal double @fwddiffe_Z6squared(double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call double* @_Z6toHeapd(double %x)
; CHECK-NEXT: %0 = call { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 1
; CHECK-NEXT: %2 = load double, double* %call, align 8
; CHECK-NEXT: %3 = load double, double* %1, align 8
; CHECK-NEXT: %4 = fmul fast double %3, %x
; CHECK-NEXT: %5 = fmul fast double %"x'", %2
; CHECK-NEXT: %6 = fadd fast double %4, %5
; CHECK-NEXT: ret double %6
; CHECK-NEXT: }

; CHECK: define internal { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call noalias nonnull dereferenceable(8) i8* @_Znwm(i64 8)
; CHECK-NEXT: %0 = call noalias nonnull dereferenceable(8) i8* @_Znwm(i64 8)
; CHECK-NEXT: %"'ipc" = bitcast i8* %0 to double*
; CHECK-NEXT: %1 = bitcast i8* %call to double*
; CHECK-NEXT: store double %x, double* %1, align 8
; CHECK-NEXT: store double %"x'", double* %"'ipc", align 8
; CHECK-NEXT: %2 = insertvalue { double*, double* } undef, double* %1, 0
; CHECK-NEXT: %3 = insertvalue { double*, double* } %2, double* %"'ipc", 1
; CHECK-NEXT: ret { double*, double* } %3
; CHECK-NEXT: }

0 comments on commit ffe46d5

Please sign in to comment.