
Make Enzyme dispatches compatible with closures #339

Closed
ChrisRackauckas opened this issue Jun 27, 2024 · 20 comments · Fixed by #375 or #407
Labels: backend (Related to one or more autodiff backends), bug (Something isn't working)

Comments

@ChrisRackauckas
Member

In the Enzyme setups https://github.com/gdalle/DifferentiationInterface.jl/blob/main/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L13 it looks like you're using the raw f. This omits the handling of any memory associated with caches, particularly within closures. The fix is rather straightforward though: you can just copy SciMLSensitivity. You wrap f in a Duplicated https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L697 where the duplicated part is just a copy f_cache = Enzyme.make_zero(f). To make this safe for repeated application, you also need to add a call to Enzyme.make_zero!(f_cache) so its duplicated values are always zero if you reuse it.
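For illustration, a minimal sketch of that pattern with a direct Enzyme call (f and x are placeholders; the actual call signatures inside the DI extension differ):

using Enzyme

# f is the user function, possibly a closure capturing caches; f_cache is its shadow,
# holding the derivatives of everything f encloses.
f_cache = Enzyme.make_zero(f)

# Differentiate with the function duplicated so enclosed caches get shadows too:
Enzyme.autodiff(Reverse, Duplicated(f, f_cache), Active, Active(x))

# Before reusing the same shadow in another call, reset it to zero:
Enzyme.make_zero!(f_cache)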

@gdalle added the bug (Something isn't working) and backend (Related to one or more autodiff backends) labels on Jun 27, 2024
@gdalle
Member

gdalle commented Jun 28, 2024

@wsmoses said this was probably a bad idea due to performance degradation, so I'm leaving the PR #341 closed for now. Are there other solutions?

@ChrisRackauckas
Member Author

Well, the other option is incorrectness (or simply erroring) if caches are used; I don't see how that's better?

@wsmoses

wsmoses commented Jun 28, 2024

I mean, honestly, this is where activity info / multi-arg support is critical.

If you have a closure (which is required by DI atm), then you'll end up differentiating every var in the original fn. So if you have something like

NN = complex neural network
DI.gradient(AutoEnzyme(), x->NN() + x, 3.1)

you'll now be forced to AD the entire neural network as opposed to the one scalar, leading to an O(1) derivative becoming unboundedly more expensive. Without the ability to handle multiple args/activity, DI would be forced to AD through the whole NN if the closure were marked active.
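For concreteness, a rough sketch of what that forced duplication looks like as a direct Enzyme call (the closure and sizes below are made up for illustration):

using Enzyme

g = let w = rand(10^6)        # stand-in for a large captured neural network
    x -> sum(w) + x           # closure capturing w
end

# With the closure Duplicated, Enzyme must build and update a shadow of the whole
# captured object, even though only the scalar input matters for this gradient:
Enzyme.autodiff(Reverse, Duplicated(g, Enzyme.make_zero(g)), Active, Active(3.1))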

@wsmoses

wsmoses commented Jun 28, 2024

Frankly, this is where I'd say it makes sense for DI to figure out how it and/or AD.jl wants to handle multiple args; for now, use direct Enzyme autodiff calls, which don't have such limitations, and revisit this question later.

@gdalle
Member

gdalle commented Jun 28, 2024

I'm slowly getting a clearer picture of how I can pull it off. But the initial plan was for AbstractDifferentiation to handle multiple arguments, so I wanna wait for @mohamed82008's approval before I dive into it within DI.

@ChrisRackauckas
Member Author

Even if DI handles multiple arguments, though, you'd still want to duplicate the function: if you don't handle enclosed caches correctly you can get incorrect derivatives, so I don't see why this would wait. The downside is that you always have to assume that all caches can be differentiated, and that is indeed a good reason to allow multiple arguments so you can Const some of them. But my point is that if we want DI to actually be correct, then we do need to ensure that the differentiation of enclosed variables carries their derivative values forward.

@ChrisRackauckas
Member Author

It at least needs to be an option: AutoEnzyme(duplicate_function = Val(true)) by default, which can be set to Val(false) as an optimization if someone wants to forcibly Const all enclosed values (at their own risk). If someone has no enclosed values there's no overhead, and if they are non-const then the default is correct, so turning it off is purely a performance optimization and I'd leave that as a user toggle. Adding that to ADTypes would be good for SciMLSensitivity as well, since we'd do the same in our implementation.
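A rough sketch of how a backend extension might branch on such a toggle (duplicate_function is the proposed option above, not an existing ADTypes field, and annotate_function is a hypothetical helper name):

using Enzyme

# Hypothetical helper: pick the activity annotation for the function based on the
# proposed duplicate_function toggle.
annotate_function(f, ::Val{true}) = Duplicated(f, Enzyme.make_zero(f))  # safe default: enclosed caches get shadows
annotate_function(f, ::Val{false}) = Const(f)                           # opt-in: assume no enclosed differentiable state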

@wsmoses

wsmoses commented Jun 28, 2024

My point about support for multiple arguments and/or activity is that they would potentially remedy the performance issue in my example.

If DI supported specifying the function as const/duplicated [aka activity], the problem is trivially remedied.

Alternatively, if multiple arguments were supported [perhaps with a Const input], you could pass the NN and/or closure data through them and again avoid the issue.
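Both remedies can be sketched today with direct Enzyme calls (the names w, h, and the sizes are illustrative):

using Enzyme

w = rand(10^6)                      # stand-in for the network parameters
h(w, x) = sum(w) + x

# (1) Activity on the function: a closure over w is passed but marked Const,
#     so Enzyme treats the captured data as constant and only differentiates x.
g = let w = w
    x -> h(w, x)
end
Enzyme.autodiff(Reverse, Const(g), Active, Active(3.1))

# (2) Multiple arguments: skip the closure and pass w explicitly as a Const argument.
Enzyme.autodiff(Reverse, Const(h), Active, Const(w), Active(3.1))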

@ChrisRackauckas
Member Author

I don't disagree with that. My point, though, is that even if DI makes all of the inputs arguments, the default activity on a function would likely be const unless the documentation showed people how to change it. I don't think that's the right default for DI, since then many common Julia functions would give wrong values. You'd basically have to say: don't pass f, the interface is Duplicated(f, make_zero(f)). That shouldn't be left to the user of DI, who should expect that the simple thing is correct. If DI.gradient(f, x) is wrong because they need DI.gradient(Duplicated(f, make_zero(f)), x) to avoid dropping derivatives on enclosed caches, I would think something has gone wrong with the interface. My suggestion is to have AutoEnzyme make the required assumption by default, which is still optimal in the case that there are no caches. Yes, it is effectively a safety copy done to make caching functions work out of the box, but with an option to turn it off at their own risk.

But also, DI shouldn't wait until multi-arg activities are supported before doing any of this. Otherwise it will have issues with user-written closures until multi-arg activities land, which arguably is a pretty nasty bug that requires a hotfix. It does mean that, yes, constants enclosed in functions will slow things down a bit because you'll differentiate more than you need to, but it also means that enclosed cache variables will correctly propagate derivatives, which is more important for a high-level interface.

I didn't test this exactly, but I would think an MWE would be as simple as:

a = [1.0]   # pre-allocated buffer that f mutates
function f(x)
  a[1] = 1.0
  a[1] += x
  a[1]^2    # f(x) == (1 + x)^2
end

would give an incorrect derivative with DI without this, which to me is a red flag that needs to be fixed. And then, when the multi-arg form arrives, we can argue whether the user needs to enable the fix or whether it comes enabled by default, but I don't think we should wait to make this work.
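An untested sketch of how one might check that concern with Enzyme directly, rewriting the MWE so the buffer is genuinely captured by a closure rather than referenced as a global:

using Enzyme

f = let a = [1.0]        # captured, mutated buffer
    x -> begin
        a[1] = 1.0
        a[1] += x
        a[1]^2
    end
end

# Analytically f(x) == (1 + x)^2, so f'(2.0) == 6.0.
# Duplicating the closure gives the captured buffer a shadow, so the derivative
# written into a is carried along:
Enzyme.autodiff(Reverse, Duplicated(f, Enzyme.make_zero(f)), Active, Active(2.0))
# Passing Const(f) instead is the case under discussion, where that contribution
# can be dropped (or error), depending on how Enzyme treats the constant buffer.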

And to be clear, I don't think Enzyme's interface should do this, but Enzyme is a much lower level utility targeting a different level of user.

@gdalle
Member

gdalle commented Jul 1, 2024

I tend to agree with Chris on this one. Until I add activities or multiple arguments, better safe and slow than fast and wrong.

@wsmoses

wsmoses commented Jul 2, 2024

I see what you're saying, but I still feel like this is an edge case that is more likely to cause problems for users than it fixes.

In particular, the only sort of case where this is needed is where you read and write to a captured buffer. With the exception of preallocation tools code, this is immensely rare in Julia functions you want to AD [largely due to non-mutation, especially non-mutation of closures].

However, by marking the entire closure as duplicated, you now need Enzyme to successfully differentiate all closure operations, including those where this read-and-write-to-captured-buffer pattern doesn't apply. If there's a function currently unhandled by Enzyme, you'll error with the duplicated fn, whereas marking it const would succeed.

To be clear, I see the arguments for both sides of this, but I'm wondering what is the better trade off to make.

@wsmoses

wsmoses commented Jul 2, 2024

Honestly, given that I'm doubtful that much code outside of PreallocationTools would have this apply, I wonder if it makes sense to just add a PreallocationTools mode to DI [which may be separately useful in its own right].

@ChrisRackauckas
Member Author

In particular, the only sort of case where this is needed is where you read and write to a captured buffer. With the exception of preallocation tools code, this is immensely rare in Julia functions you want to AD [largely due to non-mutation, especially non-mutation of closures].

That's not really the case though. It's not rare. It's actually very common and explicitly mentioned in the documentation of many packages and tutorials that one should write non-allocating code. Here is one of many examples of that:

https://docs.sciml.ai/DiffEqDocs/stable/tutorials/faster_ode_example/#Example-Accelerating-Linear-Algebra-PDE-Semi-Discretization

Such functions are made to be fully mutating and non-allocating, and also fully type-stable, and so are perfectly within the realm of Enzyme. And these functions will not error but will give the wrong answer if the closure is not duplicated, which is not the nicest behavior.

I think you're thinking specifically about Flux, which uses functors: it's effectively allocating, type-unstable functional code carrying parameters around in its objects, which may not need to be differentiated. Flux is the weird one, not everything else. I actually can't think of another library that is engineered similarly to Flux, while most scientific models, PDE solvers, etc. are engineered similarly to the example I linked above, where pre-allocated buffers are either passed around or enclosed in order to get an allocation-free runtime. And in any case, I'd argue it should be the Flux example that opts out of duplicating the closure as a performance improvement, not the scientific models, PDE solvers, etc. opting in to duplicating the function in order to ensure they get the right gradient value on repeated applications with caches.

@wsmoses

wsmoses commented Jul 2, 2024 via email

@ChrisRackauckas
Member Author

using LinearAlgebra, OrdinaryDiffEq, BenchmarkTools
# N, Ax, Ay, r0, and p are defined earlier in the linked tutorial.
Ayu = zeros(N, N)
uAx = zeros(N, N)
Du = zeros(N, N)
Ayv = zeros(N, N)
vAx = zeros(N, N)
Dv = zeros(N, N)
function gm3!(dr, r, p, t)
    a, α, ubar, β, D1, D2 = p
    u = @view r[:, :, 1]
    v = @view r[:, :, 2]
    du = @view dr[:, :, 1]
    dv = @view dr[:, :, 2]
    mul!(Ayu, Ay, u)
    mul!(uAx, u, Ax)
    mul!(Ayv, Ay, v)
    mul!(vAx, v, Ax)
    @. Du = D1 * (Ayu + uAx)
    @. Dv = D2 * (Ayv + vAx)
    @. du = Du + a * u * u ./ v + ubar - α * u
    @. dv = Dv + a * u * u - β * v
end
prob = ODEProblem(gm3!, r0, (0.0, 0.1), p)
@btime solve(prob, Tsit5());

@wsmoses

wsmoses commented Jul 2, 2024 via email

@gdalle
Member

gdalle commented Jul 15, 2024

See SciML/SciMLBenchmarks.jl#988 for a more involved discussion

@gdalle
Member

gdalle commented Jul 20, 2024

The first ingredient of the solution is available in the latest ADTypes release: AutoEnzyme(constant_function=true/false). Now it's on me to implement both variants here.
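For reference, a usage sketch of that toggle as described above (keyword name taken from that ADTypes release; the exact semantics are what this thread is deciding):

using ADTypes

# Do not assume the function is constant: the backend may duplicate it so that
# enclosed caches propagate derivatives (the safe-but-slower choice discussed above).
backend_safe = AutoEnzyme(constant_function=false)

# Promise that the function and anything it encloses carries no differentiable state,
# so it can be passed as Const (the faster, at-your-own-risk choice).
backend_fast = AutoEnzyme(constant_function=true)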

@gdalle gdalle linked a pull request Jul 22, 2024 that will close this issue
@gdalle gdalle reopened this Jul 25, 2024
@gdalle
Member

gdalle commented Jul 25, 2024

@willtebbutt what kind of assumptions does Tapir make vis-a-vis constant functions?

@willtebbutt
Member

At present, Tapir.jl assumes that all arguments are active, and differentiates through everything. Consequently I'm reasonably confident that there's nothing here that is relevant to Tapir.jl.
