Skip to content

Commit df488bd

Browse files
Handle Enzyme tape size of zero (rust-lang#368)
1 parent 36f3923 commit df488bd

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

Diff for: enzyme/Enzyme/Enzyme.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,12 @@ class Enzyme : public ModulePass {
978978
? aug->fn->getReturnType()
979979
: cast<StructType>(aug->fn->getReturnType())
980980
->getElementType(tapeIdx);
981+
} else {
982+
if (sizeOnly) {
983+
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false));
984+
CI->eraseFromParent();
985+
return true;
986+
}
981987
}
982988
if (sizeOnly) {
983989
auto size = DL.getTypeSizeInBits(tapeType) / 8;

Diff for: enzyme/test/Enzyme/ReverseMode/splitSize5.ll

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind readnone uwtable
4+
define double @tester(double* %x) {
5+
entry:
6+
%gep = getelementptr double, double* %x, i32 1
7+
%y = load double, double* %x
8+
%z = load double, double* %gep
9+
%res = fadd fast double %y, %z
10+
ret double %res
11+
}
12+
13+
define void @test_derivative(double* %x, double* %dx) {
14+
entry:
15+
%size = call i64 (double (double*)*, ...) @__enzyme_augmentsize(double (double*)* nonnull @tester, metadata !"enzyme_dup")
16+
%cache = alloca i8, i64 %size, align 1
17+
call void (double (double*)*, ...) @__enzyme_augmentfwd(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
18+
tail call void (double (double*)*, ...) @__enzyme_reverse(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
19+
ret void
20+
}
21+
22+
; Function Attrs: nounwind
23+
declare void @__enzyme_augmentfwd(double (double*)*, ...)
24+
declare i64 @__enzyme_augmentsize(double (double*)*, ...)
25+
declare void @__enzyme_reverse(double (double*)*, ...)
26+
27+
; CHECK: define void @test_derivative(double* %x, double* %dx)
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %0 = call fast double @augmented_tester(double* %x, double* %dx)
30+
; CHECK-NEXT: call void @diffetester(double* %x, double* %dx, double 1.000000e+00)
31+
; CHECK-NEXT: ret void
32+
; CHECK-NEXT:}

0 commit comments

Comments
 (0)