
Derivatives of function imported from module not working properly #173

Open
g-bauer opened this issue Aug 20, 2024 · 4 comments

Comments


g-bauer commented Aug 20, 2024

When importing a function (annotated with the autodiff macro) from a module, the derivative is missing: the primal value is correct, but the derivative component comes back as 0.0. I created a minimal example here.

# identical function as below, but defined in lib.rs
[src/main.rs:34:5] enzyme_y1_lib = (
    4.497780053946161,
    0.0, # <---
)
# identical function as above, but defined in main.rs
[src/main.rs:35:5] enzyme_y1f = (
    4.497780053946161,
    4.05342789389862,
)
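The numbers above are consistent with `f1(x) = exp(x) / sqrt(sin(x)^3 + cos(x)^3)` evaluated at `x = 1.5` (inferred from the printed values; the exact call site is in the linked repo). A plain-Rust finite-difference check, independent of Enzyme, confirms the derivative that the main.rs build returns:

```rust
// Sanity check for the values reported above. `f1` mirrors the function
// from the linked repo; the central finite difference is a stand-in for
// the missing Enzyme tangent, just to confirm the expected second tuple
// element. `x = 1.5` is inferred, not taken from the repo directly.
fn f1(x: f64) -> f64 {
    x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
}

// Central finite difference as an independent reference derivative.
fn df1_fd(x: f64) -> f64 {
    let h = 1e-6;
    (f1(x + h) - f1(x - h)) / (2.0 * h)
}

fn main() {
    let x = 1.5;
    // Matches the primal value both builds agree on.
    assert!((f1(x) - 4.497780053946161).abs() < 1e-9);
    // Matches the derivative the main.rs build returns; the lib.rs
    // build incorrectly yields 0.0 here.
    assert!((df1_fd(x) - 4.05342789389862).abs() < 1e-4);
    println!("f1({x}) = {}, f1'({x}) ≈ {}", f1(x), df1_fd(x));
}
```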

Meta

rustc --version --verbose:

rustc 1.82.0-nightly (86dedf7dc 2024-08-16)
binary: rustc
commit-hash: 86dedf7dc5b63661998a038c726033ad92c2d40e
commit-date: 2024-08-16
host: x86_64-unknown-linux-gnu
release: 1.82.0-nightly
LLVM version: 19.1.0
@g-bauer g-bauer changed the title Function imported from module not working properly Derivatives of function imported from module not working properly Aug 20, 2024
ZuseZ4 (Member) commented Aug 21, 2024

Even setting codegen-units=1 doesn't fix this, so I'll have to look into how rustc compiles lib.rs + main.rs here.
As for the performance, I was deeply confused at first.
To start, your function lowers to this IR, which looks reasonable:

; enzyme_playground::_f1
; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind nonlazybind sanitize_hwaddress willreturn memory(none) uwtable
define noundef double @_ZN17enzyme_playground3_f117h5f4eabd9c1eee14aE(double noundef %x) unnamed_addr #1 {
start:
  %0 = tail call double @llvm.exp.f64(double %x)
  %1 = tail call double @llvm.sin.f64(double %x)
  %2 = tail call double @llvm.powi.f64.i32(double %1, i32 3)
  %3 = tail call double @llvm.cos.f64(double %x)
  %4 = tail call double @llvm.powi.f64.i32(double %3, i32 3)
  %_4.i = fadd double %2, %4
  %5 = tail call double @llvm.sqrt.f64(double %_4.i)
  %_0.i = fdiv double %0, %5
  ret double %_0.i
}

I was also right in remembering that I have one more indirection, see this:

; Function Attrs: noinline nonlazybind uwtable
define internal { double, double } @_ZN17enzyme_playground4dfff17haf2f79cc33487406E(double noundef %0, double noundef %1) unnamed_addr #2 {
  %3 = call { double, double } @fwddiffe_f1(double %0, double %1)
  ret { double, double } %3
}

But once you look inside the Enzyme-generated function, it calls one function per original instruction. That is, we have 5 or 6 function calls on top of the extra indirection mentioned above. For such simple code, that will completely kill performance. Luckily, Enzyme has some debug flags, described here: https://enzyme.mit.edu/index.fcgi/rust/Debugging.html
So now we run RUSTFLAGS="-Z autodiff=PrintModAfterEnzyme,Inline" cargo +enzyme build --release &> mod.ll
and get something much more sensible:

; Function Attrs: mustprogress noinline nonlazybind willreturn uwtable
define internal { double, double } @fwddiffe_f2(double noundef "enzyme_type"="{[-1]:Float@double}" %0, double "enzyme_type"="{[-1]:Float@double}" %1) unnamed_addr #122 personality ptr @rust_eh_personality {
  call void @llvm.experimental.noalias.scope.decl(metadata !45833) #127
  %3 = call noundef double @llvm.exp.f64(double %0) #127
  %4 = call fast double @llvm.exp.f64(double %0)
  %5 = fmul fast double %1, %4
  call void @llvm.experimental.noalias.scope.decl(metadata !45836) #127
  %6 = call noundef double @llvm.sin.f64(double %0) #127
  %7 = call fast double @llvm.cos.f64(double %0)
  %8 = fmul fast double %1, %7
  call void @llvm.experimental.noalias.scope.decl(metadata !45839) #127
  %9 = call noundef double @llvm.powi.f64.i32(double %6, i32 3) #127
  %10 = fcmp fast oeq double %8, 0.000000e+00
  %11 = and i1 false, %10
  %12 = or i1 false, %11
  %13 = call fast double @llvm.powi.f64.i32(double %6, i32 2)
  %14 = fmul fast double 3.000000e+00, %13
  %15 = fmul fast double %8, %14
  %16 = select fast i1 %12, double 0.000000e+00, double %15
  call void @llvm.experimental.noalias.scope.decl(metadata !45842) #127
  %17 = call noundef double @llvm.cos.f64(double %0) #127
  %18 = call fast double @llvm.sin.f64(double %0)
  %19 = fneg fast double %18
  %20 = fmul fast double %1, %19
  call void @llvm.experimental.noalias.scope.decl(metadata !45845) #127
  %21 = call noundef double @llvm.powi.f64.i32(double %17, i32 3) #127
  %22 = fcmp fast oeq double %20, 0.000000e+00
  %23 = and i1 false, %22
  %24 = or i1 false, %23
  %25 = call fast double @llvm.powi.f64.i32(double %17, i32 2)
  %26 = fmul fast double 3.000000e+00, %25
  %27 = fmul fast double %20, %26
  %28 = select fast i1 %24, double 0.000000e+00, double %27
  %29 = fadd double %9, %21
  %30 = fadd fast double %16, %28
  call void @llvm.experimental.noalias.scope.decl(metadata !45848) #127
  %31 = call noundef double @llvm.sqrt.f64(double %29) #127
  %32 = fcmp fast ueq double %29, 0.000000e+00
  %33 = call fast double @llvm.sqrt.f64(double %29) #128
  %34 = fmul fast double 2.000000e+00, %33
  %35 = fdiv fast double %30, %34
  %36 = select fast i1 %32, double 0.000000e+00, double %35
  %37 = fdiv double %3, %31
  %38 = fmul fast double %5, %31
  %39 = fmul fast double %36, %3
  %40 = fsub fast double %38, %39
  %41 = fmul fast double %31, %31
  %42 = fdiv fast double %40, %41
  %43 = insertvalue { double, double } undef, double %37, 0
  %44 = insertvalue { double, double } %43, double %42, 1
  ret { double, double } %44
}
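To make the generated code easier to read: the IR above is just forward-mode dual-number arithmetic, one primal/tangent pair per original operation. A hand-written plain-Rust equivalent (my naming and structure, not Enzyme's actual code) would be:

```rust
// Hand-written sketch of what `fwddiffe_f2` computes: each (value,
// tangent) pair corresponds to a primal/derivative instruction pair in
// the IR dump. Register numbers in comments refer to that dump.
fn f1_fwd(x: f64, dx: f64) -> (f64, f64) {
    let e = x.exp();
    let de = dx * e;                        // %4, %5
    let s = x.sin();
    let ds = dx * x.cos();                  // %7, %8
    let s3 = s.powi(3);
    let ds3 = ds * 3.0 * s.powi(2);         // %13..%15
    let c = x.cos();
    let dc = dx * -x.sin();                 // %18..%20
    let c3 = c.powi(3);
    let dc3 = dc * 3.0 * c.powi(2);         // %25..%27
    let sum = s3 + c3;
    let dsum = ds3 + dc3;                   // %29, %30
    let r = sum.sqrt();
    // Guarded like the `select` in the IR: the tangent of sqrt at 0
    // is forced to 0 instead of becoming infinite.
    let dr = if sum == 0.0 { 0.0 } else { dsum / (2.0 * r) }; // %32..%36
    let v = e / r;                          // %37
    let dv = (de * r - dr * e) / (r * r);   // %38..%42 (quotient rule)
    (v, dv)
}

fn main() {
    let (v, dv) = f1_fwd(1.5, 1.0);
    assert!((v - 4.497780053946161).abs() < 1e-9);
    assert!((dv - 4.05342789389862).abs() < 1e-9);
    println!("value = {v}, derivative = {dv}");
}
```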

Benchmarking times aren't affected, though. Mine were better for Enzyme from the beginning (415 ps instead of the 500 ps you have in your repo), but they are the same for me with and without the flag.

Enzyme: 1st order/forward
                        time:   [415.55 ps 415.68 ps 415.86 ps]
num-dual: 1st order/Dual64
                        time:   [308.20 ps 308.84 ps 309.84 ps]

Now that we know Enzyme's extra indirection is likely the issue, let's handicap num-dual the same way and benchmark the indirection:

+pub fn indirection<D: DualNum<f64>>(x: D) -> D {
+    f1(x)
+}
+
+#[inline(never)]
 pub fn f1<D: DualNum<f64>>(x: D) -> D {
     x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
 }

Now we get:

num-dual: 1st order/Dual64
                        time:   [408.21 ps 408.69 ps 409.33 ps]
                        change: [+30.526% +31.881% +33.197%] (p = 0.00 < 0.05)
                        Performance has regressed.

And indeed, it's down to Enzyme's level. So in summary: Enzyme currently always has one more layer of indirection because I didn't bother cleaning up the LLVM IR enough. I never noticed, because it's easily hidden by any slightly more complex operation, but here we have just 5 simple operations, so it actually has a measurable effect. I can't promise to fix it soon, since it likely won't be measurable beyond toy examples, but I'll leave this open as a reminder. It would also be an easy way to get started; it shouldn't be too hard for a new contributor.

ZuseZ4 (Member) commented Aug 21, 2024

I pushed ab54904 to remove one layer of indirection, but interestingly enough it had no performance impact. I'll see if I can inline even the call to the differentiated function; that might help. In the meantime, please feel free to post the LLVM IR of your function (cargo has a flag for that). Maybe we can spot the difference that way?

ZuseZ4 (Member) commented Aug 21, 2024

@g-bauer Eventually this could be the reason why we're slower. There is a whole discussion on correctness in AD here: EnzymeAD/Enzyme#1295
Do you have special handling for sqrt(0) in your tool? I tend not to adjust the default behaviour even if it has a small performance overhead, since for non-toy examples the performance benefits of LLVM-based AD should easily cover it. Do you have any larger benchmarks we could compare on?

g-bauer (Author) commented Aug 21, 2024

We don't have special treatment (see here). Taking the derivative of sqrt(0.0) will return NaN. But I don't think that's the issue here: switching to a different operation (pow, ln, ...) doesn't change the benchmark results on my machine.
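For concreteness, here is the sqrt(0) corner case in plain Rust (a sketch, not num-dual's or Enzyme's actual code): the naive forward-mode rule for sqrt produces an infinite tangent at x = 0, which turns into NaN as soon as it meets a zero factor downstream, while a guarded variant like the `select` in the IR dump above returns 0.0 instead.

```rust
// Naive forward-mode tangent of sqrt: d/dx sqrt(x) = dx / (2 * sqrt(x)).
// At x = 0 this divides by zero, giving an infinite tangent (or NaN
// once it is multiplied by a zero factor later in the chain rule).
fn dsqrt_naive(x: f64, dx: f64) -> f64 {
    dx / (2.0 * x.sqrt())
}

// Guarded variant, analogous to Enzyme's `select` in the IR above:
// the tangent at x = 0 is forced to 0.0 instead.
fn dsqrt_guarded(x: f64, dx: f64) -> f64 {
    if x == 0.0 { 0.0 } else { dx / (2.0 * x.sqrt()) }
}

fn main() {
    assert!(dsqrt_naive(0.0, 1.0).is_infinite());
    assert!((0.0 * dsqrt_naive(0.0, 1.0)).is_nan()); // inf * 0 = NaN downstream
    assert_eq!(dsqrt_guarded(0.0, 1.0), 0.0);
    // Away from zero, both agree: d/dx sqrt(4) = 0.25.
    assert!((dsqrt_naive(4.0, 1.0) - 0.25).abs() < 1e-15);
}
```

The guard costs a compare and a select per sqrt, which is the small performance overhead mentioned above.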

I'll add a longer example.
