Skip to content

Commit

Permalink
Handle StructReturn (rust-lang#338)
Browse files Browse the repository at this point in the history
* Handle sret

* add tests
  • Loading branch information
tgymnich authored Oct 11, 2021
1 parent 6e45ead commit 534beda
Show file tree
Hide file tree
Showing 7 changed files with 624 additions and 7 deletions.
62 changes: 55 additions & 7 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ class Enzyme : public ModulePass {

Value *fn = CI->getArgOperand(0);

std::vector<DIFFE_TYPE> constants;
SmallVector<Value *, 2> args;

if (CI->paramHasAttr(0, Attribute::StructRet)) {
fn = CI->getArgOperand(1);
}

while (auto ci = dyn_cast<CastInst>(fn)) {
fn = ci->getOperand(0);
}
Expand All @@ -344,14 +351,43 @@ class Enzyme : public ModulePass {
auto FT = cast<Function>(fn)->getFunctionType();
assert(fn);

if (EnzymePrint)
llvm::errs() << "prefn:\n" << *fn << "\n";
bool sret = CI->paramHasAttr(0, Attribute::StructRet) ||
cast<Function>(fn)->hasParamAttribute(0, Attribute::StructRet);

std::vector<DIFFE_TYPE> constants;
SmallVector<Value *, 2> args;
IRBuilder<> Builder(CI);

unsigned truei = 0;
IRBuilder<> Builder(CI);
if (cast<Function>(fn)->hasParamAttribute(0, Attribute::StructRet)) {
Type *fnsrety = cast<PointerType>(FT->getParamType(0));

truei = 1;

const DataLayout &DL = CI->getParent()->getModule()->getDataLayout();
Type *Ty = fnsrety->getPointerElementType();
#if LLVM_VERSION_MAJOR >= 11
AllocaInst *primal = new AllocaInst(Ty, DL.getAllocaAddrSpace(), nullptr,
DL.getPrefTypeAlign(Ty));
#else
AllocaInst *primal = new AllocaInst(Ty, DL.getAllocaAddrSpace(), nullptr);
#endif

primal->insertBefore(CI);

Value *shadow;
if (mode == DerivativeMode::ForwardMode) {
shadow = CI->getArgOperand(0);
} else {
shadow = CI->getArgOperand(1);
sret = true;
}

args.push_back(primal);
args.push_back(shadow);
constants.push_back(DIFFE_TYPE::DUP_ARG);
}

if (EnzymePrint)
llvm::errs() << "prefn:\n" << *fn << "\n";

auto Arch =
llvm::Triple(
Expand All @@ -371,7 +407,7 @@ class Enzyme : public ModulePass {
llvm::Value *tape = nullptr;
bool tapeIsPointer = false;
int allocatedTapeSize = -1;
for (unsigned i = 1; i < CI->getNumArgOperands(); ++i) {
for (unsigned i = 1 + sret; i < CI->getNumArgOperands(); ++i) {
Value *res = CI->getArgOperand(i);

if (truei >= FT->getNumParams()) {
Expand Down Expand Up @@ -1053,7 +1089,9 @@ class Enzyme : public ModulePass {
}

if (!diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy() &&
!CI->getType()->isEmptyTy() && !CI->getType()->isVoidTy()) {
!CI->getType()->isEmptyTy() &&
(!CI->getType()->isVoidTy() ||
CI->paramHasAttr(0, Attribute::StructRet))) {
if (diffret->getType() == CI->getType()) {
CI->replaceAllUsesWith(diffret);
} else if (mode == DerivativeMode::ReverseModePrimal) {
Expand All @@ -1071,6 +1109,16 @@ class Enzyme : public ModulePass {
llvm::errs() << *CI << " - " << *diffret << "\n";
assert(0 && " what");
}
} else if (CI->paramHasAttr(0, Attribute::StructRet)) {
Value *sret = CI->getArgOperand(0);

if (StructType *st = cast<StructType>(diffret->getType())) {
for (unsigned int i = 0;
i < diffret->getType()->getStructNumElements(); i++) {
Builder.CreateStore(Builder.CreateExtractValue(diffret, {i}),
Builder.CreateStructGEP(sret, i));
}
}
} else {

unsigned idxs[] = {0};
Expand Down
106 changes: 106 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/sret.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
; RUN: if [ %llvmver -lt 12 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s ; fi


; #include <stdio.h>
; #include <array>

; using namespace std;

; extern array<double,3> __enzyme_fwddiff(void*, ...);

; array<double,3> square(double x) {
; return {x * x, x * x * x, x};
; }
; array<double,3> dsquare(double x) {
; // This returns the derivative of square, i.e. {2 * x, 3 * x * x, 1}
; return __enzyme_fwddiff((void*)square, x, 1.0);
; }
; int main() {
; printf("%f \n", dsquare(3)[0]);
; }


%"struct.std::array" = type { [3 x double] }

@.str = private unnamed_addr constant [5 x i8] c"%f \0A\00", align 1

; Primal function under test (the `square` routine from the C++ listing above).
; The result is returned through the sret pointer %agg.result: the three
; doubles of the wrapped array are set to x*x, (x*x)*x and x, in that order.
define dso_local void @_Z6squared(%"struct.std::array"* noalias nocapture sret %agg.result, double %x) #0 {
entry:
; element 0 <- x * x
%arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
%mul = fmul double %x, %x
store double %mul, double* %arrayinit.begin, align 8
; element 1 <- (x * x) * x, reusing %mul
%arrayinit.element = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 1
%mul2 = fmul double %mul, %x
store double %mul2, double* %arrayinit.element, align 8
; element 2 <- x
%arrayinit.element3 = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 2
store double %x, double* %arrayinit.element3, align 8
ret void
}

; Caller of the forward-mode differentiation placeholder. Its own sret pointer
; %agg.result is passed straight through as operand 0 of the placeholder call,
; so the function to differentiate (@_Z6squared, bitcast to i8*) is operand 1,
; followed by the input %x and the constant 1.0 that accompanies it.
; The FileCheck expectations at the end of this file require this call to be
; rewritten into a call to @fwddiffe_Z6squared with a fresh alloca as the first
; pointer argument and %agg.result as the second.
define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret %agg.result, double %x) local_unnamed_addr #1 {
entry:
tail call void (%"struct.std::array"*, i8*, ...) @_Z16__enzyme_fwddiffPvz(%"struct.std::array"* sret %agg.result, i8* bitcast (void (%"struct.std::array"*, double)* @_Z6squared to i8*), double %x, double 1.000000e+00)
ret void
}

declare dso_local void @_Z16__enzyme_fwddiffPvz(%"struct.std::array"* sret, i8*, ...) local_unnamed_addr #2

; Driver: calls the differentiation placeholder directly (the body of
; @_Z7dsquared appears here in inlined form) with x = 3.0 and 1.0, writing the
; result into a stack temporary, then prints element 0 of that result.
define dso_local i32 @main() local_unnamed_addr #3 {
entry:
; 24-byte stack slot that receives the sret result; the lifetime markers
; bracket its use.
%ref.tmp = alloca %"struct.std::array", align 8
%0 = bitcast %"struct.std::array"* %ref.tmp to i8*
call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %0) #6
call void (%"struct.std::array"*, i8*, ...) @_Z16__enzyme_fwddiffPvz(%"struct.std::array"* nonnull sret %ref.tmp, i8* bitcast (void (%"struct.std::array"*, double)* @_Z6squared to i8*), double 3.000000e+00, double 1.000000e+00)
; load element 0 of the result and pass it to printf with the "%f \n" format
%arrayidx.i.i = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %ref.tmp, i64 0, i32 0, i64 0
%1 = load double, double* %arrayidx.i.i, align 8
%call1 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i64 0, i64 0), double %1)
call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %0) #6
ret i32 0
}

declare dso_local i32 @printf(i8* nocapture readonly, ...) local_unnamed_addr #4

declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #5

declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #5

attributes #0 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #3 = { norecurse uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #4 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #5 = { argmemonly nounwind }
attributes #6 = { nounwind }


; CHECK: define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret %agg.result, double %x)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = alloca %"struct.std::array"
; CHECK-NEXT: call void @fwddiffe_Z6squared(%"struct.std::array"* %0, %"struct.std::array"* %agg.result, double %x, double 1.000000e+00)
; CHECK-NEXT: ret void
; CHECK-NEXT: }


; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'") #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
; CHECK-NEXT: %mul = fmul double %x, %x
; CHECK-NEXT: %0 = fmul fast double %"x'", %x
; CHECK-NEXT: %1 = fadd fast double %0, %0
; CHECK-NEXT: store double %mul, double* %arrayinit.begin, align 8
; CHECK-NEXT: store double %1, double* %"arrayinit.begin'ipg", align 8
; CHECK-NEXT: %"arrayinit.element'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 1
; CHECK-NEXT: %arrayinit.element = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 1
; CHECK-NEXT: %mul2 = fmul double %mul, %x
; CHECK-NEXT: %2 = fmul fast double %1, %x
; CHECK-NEXT: %3 = fmul fast double %"x'", %mul
; CHECK-NEXT: %4 = fadd fast double %2, %3
; CHECK-NEXT: store double %mul2, double* %arrayinit.element, align 8
; CHECK-NEXT: store double %4, double* %"arrayinit.element'ipg", align 8
; CHECK-NEXT: %"arrayinit.element3'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 2
; CHECK-NEXT: %arrayinit.element3 = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 2
; CHECK-NEXT: store double %x, double* %arrayinit.element3, align 8
; CHECK-NEXT: store double %"x'", double* %"arrayinit.element3'ipg", align 8
; CHECK-NEXT: ret void
; CHECK-NEXT: }
106 changes: 106 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/sret12.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s ; fi


; #include <stdio.h>
; #include <array>

; using namespace std;

; extern array<double,3> __enzyme_fwddiff(void*, ...);

; array<double,3> square(double x) {
; return {x * x, x * x * x, x};
; }
; array<double,3> dsquare(double x) {
; // This returns the derivative of square, i.e. {2 * x, 3 * x * x, 1}
; return __enzyme_fwddiff((void*)square, x, 1.0);
; }
; int main() {
; printf("%f \n", dsquare(3)[0]);
; }


%"struct.std::array" = type { [3 x double] }

@.str = private unnamed_addr constant [5 x i8] c"%f \0A\00", align 1

; Primal function under test (the `square` routine from the C++ listing above),
; written with the typed `sret(<ty>)` attribute form that this variant of the
; test uses. The result is returned through %agg.result: the three doubles of
; the wrapped array are set to x*x, (x*x)*x and x, in that order.
define dso_local void @_Z6squared(%"struct.std::array"* noalias nocapture sret(%"struct.std::array") align 8 %agg.result, double %x) #0 {
entry:
; element 0 <- x * x
%arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
%mul = fmul double %x, %x
store double %mul, double* %arrayinit.begin, align 8
; element 1 <- (x * x) * x, reusing %mul
%arrayinit.element = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 1
%mul2 = fmul double %mul, %x
store double %mul2, double* %arrayinit.element, align 8
; element 2 <- x
%arrayinit.element3 = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 2
store double %x, double* %arrayinit.element3, align 8
ret void
}

; Caller of the forward-mode differentiation placeholder. Its own sret pointer
; %agg.result is passed straight through as operand 0 of the placeholder call,
; so the function to differentiate (@_Z6squared, bitcast to i8*) is operand 1,
; followed by the input %x and the constant 1.0 that accompanies it.
; The FileCheck expectations at the end of this file require this call to be
; rewritten into a call to @fwddiffe_Z6squared with a fresh alloca as the first
; pointer argument and %agg.result as the second.
define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret(%"struct.std::array") align 8 %agg.result, double %x) local_unnamed_addr #1 {
entry:
tail call void (%"struct.std::array"*, i8*, ...) @_Z16__enzyme_fwddiffPvz(%"struct.std::array"* sret(%"struct.std::array") align 8 %agg.result, i8* bitcast (void (%"struct.std::array"*, double)* @_Z6squared to i8*), double %x, double 1.000000e+00)
ret void
}

declare dso_local void @_Z16__enzyme_fwddiffPvz(%"struct.std::array"* sret(%"struct.std::array") align 8, i8*, ...) local_unnamed_addr #2

; Driver: calls the differentiation placeholder directly (the body of
; @_Z7dsquared appears here in inlined form) with x = 3.0 and 1.0, writing the
; result into a stack temporary, then prints element 0 of that result.
define dso_local i32 @main() local_unnamed_addr #3 {
entry:
; 24-byte stack slot that receives the sret result; the lifetime markers
; bracket its use.
%ref.tmp = alloca %"struct.std::array", align 8
%0 = bitcast %"struct.std::array"* %ref.tmp to i8*
call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %0) #6
call void (%"struct.std::array"*, i8*, ...) @_Z16__enzyme_fwddiffPvz(%"struct.std::array"* nonnull sret(%"struct.std::array") align 8 %ref.tmp, i8* bitcast (void (%"struct.std::array"*, double)* @_Z6squared to i8*), double 3.000000e+00, double 1.000000e+00)
; load element 0 of the result and pass it to printf with the "%f \n" format
%arrayidx.i.i = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %ref.tmp, i64 0, i32 0, i64 0
%1 = load double, double* %arrayidx.i.i, align 8
%call1 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i64 0, i64 0), double %1)
call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %0) #6
ret i32 0
}

declare dso_local noundef i32 @printf(i8* nocapture noundef readonly, ...) local_unnamed_addr #4

declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #5

declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #5

attributes #0 = { nofree norecurse nounwind uwtable willreturn writeonly mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #3 = { norecurse uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #4 = { nofree nounwind "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #5 = { argmemonly nofree nosync nounwind willreturn }
attributes #6 = { nounwind }


; CHECK: define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret(%"struct.std::array") align 8 %agg.result, double %x)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = alloca %"struct.std::array"
; CHECK-NEXT: call void @fwddiffe_Z6squared(%"struct.std::array"* %0, %"struct.std::array"* %agg.result, double %x, double 1.000000e+00)
; CHECK-NEXT: ret void
; CHECK-NEXT: }


; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'") #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
; CHECK-NEXT: %mul = fmul double %x, %x
; CHECK-NEXT: %0 = fmul fast double %"x'", %x
; CHECK-NEXT: %1 = fadd fast double %0, %0
; CHECK-NEXT: store double %mul, double* %arrayinit.begin, align 8
; CHECK-NEXT: store double %1, double* %"arrayinit.begin'ipg", align 8
; CHECK-NEXT: %"arrayinit.element'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 1
; CHECK-NEXT: %arrayinit.element = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 1
; CHECK-NEXT: %mul2 = fmul double %mul, %x
; CHECK-NEXT: %2 = fmul fast double %1, %x
; CHECK-NEXT: %3 = fmul fast double %"x'", %mul
; CHECK-NEXT: %4 = fadd fast double %2, %3
; CHECK-NEXT: store double %mul2, double* %arrayinit.element, align 8
; CHECK-NEXT: store double %4, double* %"arrayinit.element'ipg", align 8
; CHECK-NEXT: %"arrayinit.element3'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 2
; CHECK-NEXT: %arrayinit.element3 = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 2
; CHECK-NEXT: store double %x, double* %arrayinit.element3, align 8
; CHECK-NEXT: store double %"x'", double* %"arrayinit.element3'ipg", align 8
; CHECK-NEXT: ret void
; CHECK-NEXT: }
Loading

0 comments on commit 534beda

Please sign in to comment.