
Add ChainRules rules #40

Open
sethaxen opened this issue Jul 24, 2020 · 12 comments

Comments

@sethaxen

From Slack:
@sethaxen:

Does ExponentialUtilities.jl play well with AD packages, in particular Zygote?

@ChrisRackauckas:

not fully with Zygote
it'll need adjoints
since it's doing a lot of scalar stuff
it's writing the kernels directly
the adjoints are easy though

I in particular need adjoints for expv. Zygote currently has an adjoint rule for exp(::AbstractMatrix) and exp(::Hermitian) using the eigendecomposition. I imagine though there's a better way to implement the adjoints for expv by looking at the underlying algorithm (I have not).

@ChrisRackauckas
Member

There should be ways to do this without defining the Jacobian. expv is the solution to a linear ODE, so the adjoint of that ODE should let us derive the expression in terms of the adjoint, which IIRC should just be:

du = Au -> dlambda -> lambda'*A

which means the adjoint should just be expv(A,delta').

@sethaxen
Author

For λ = expv(t, A, u) with adjoint Δλ, the adjoint of u should, I think, be ∂u = expv(t, A', Δλ), which is quite nice.
We also need the adjoints ∂t and ∂A, which will take more thought.
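As a quick numerical sanity check of the u-adjoint claim, here is a NumPy/SciPy sketch (illustrative only, not this package's API; it uses the dense identity expv(t, A, u) = exp(tA)u):

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(0)
n = 5
t = 0.7
A = rng.standard_normal((n, n))
u = rng.standard_normal(n)
dlam = rng.standard_normal(n)   # incoming adjoint Δλ
du = rng.standard_normal(n)     # arbitrary tangent of u

# λ = exp(tA) u is linear in u, so the pushforward of du is exp(tA) du.
pushforward = expm(t * A) @ du
# Claimed pullback: ∂u = exp(t A') Δλ, which equals exp(tA)' Δλ.
pullback = expm(t * A.T) @ dlam

# The adjoint identity ⟨Δλ, dλ⟩ = ⟨∂u, du⟩ must hold for every du.
lhs = dlam @ pushforward
rhs = pullback @ du
print(np.isclose(lhs, rhs))  # True
```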

@sethaxen
Author

sethaxen commented Jul 25, 2020

Especially since A doesn't even need to be a matrix, right? I don't think we'll be able to support all types of A for a custom adjoint, just AbstractMatrix subtypes.

@ChrisRackauckas
Member

Yeah, the difficult thing will be supporting something that's not concrete, since then it can't be adjointed directly. But then that's just defined as the reverse mode of the function f(u) = A*u, so I think it can work out; it'll just be more complicated in code.

Those again would come from this derivation. You might want to read https://diffeq.sciml.ai/stable/extras/sensitivity_math/ or the supplemental of https://arxiv.org/abs/2001.04385 . Specifically, the ∂A term is given by an integral over the Lagrange multiplier term. Coincidentally, the phiv values used in the exponential integrators are these integrals, so the adjoint can probably be written as just a calculation of phi_1. I think it's something like phiv(t, A', Δλ) + reversemode(A) (in pseudocode, off the top of my head, so maybe missing a detail somewhere).

∂t is easy in this interpretation: λ = expv(t, A, u) = exp(t*A)*u solves λ′ = Aλ with λ(0) = u, evaluated at time t, so the derivative of the solution w.r.t. t is just Aλ (or in reverse mode, maybe with A').

Again, all of this might be missing a detail since I'm doing it quickly, but that should be the gist of it.
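The ∂t part can be checked numerically. A NumPy/SciPy sketch (again using exp(tA)u in place of expv, so not the package's API): since dλ/dt = Aλ, the scalar pullback is ⟨Δλ, Aλ⟩.

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(1)
n = 5
t = 0.7
A = rng.standard_normal((n, n))
u = rng.standard_normal(n)
dlam = rng.standard_normal(n)   # incoming adjoint Δλ

lam = expm(t * A) @ u
# d/dt exp(tA) u = A exp(tA) u = A λ, so the pullback to t is ⟨Δλ, Aλ⟩.
dt_adjoint = dlam @ (A @ lam)

# Finite-difference check of d⟨Δλ, λ(t)⟩/dt
h = 1e-6
fd = (dlam @ (expm((t + h) * A) @ u) - dlam @ (expm((t - h) * A) @ u)) / (2 * h)
print(np.isclose(dt_adjoint, fd))  # True
```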

@sethaxen
Author

Thanks! That should be enough to get me started. I'll probably tackle this in a few months if no one else does before then (unless I find some time early).

@sethaxen
Author

Working on this now and have some follow-up questions.

Specifically, the ∂A term is given by an integral over the Lagrange multiplier term. Coincidentally, the phiv values used in the exponential integrators are these integrals, so the adjoint can probably be written as just a calculation of phi_1. I think it's something like phiv(t, A', Δλ) + reversemode(A) (in pseudocode, off the top of my head, so maybe missing a detail somewhere).

I've spent some time working through the provided references and still haven't fully understood this comment. What is reversemode(A) here? By phiv(t, A', Δλ) do you mean phiv(t, A', Δλ, 1)[:, 2], which I believe computes \phi_1(A') Δλ? This would compute an adjoint with the same dimensions as v, not a matrix.

@ChrisRackauckas
Member

Hmm, I guess it doesn't use the phi_1. It is the first integral of the term so I'm a little surprised it doesn't show up.

@sethaxen
Author

sethaxen commented Dec 20, 2020

Okay, I think I worked something out for forward mode at least. The pushforward of u = expv(t, A, u_0) is (using slide 5 of http://www1.maths.leeds.ac.uk/~jitse/scicade09.pdf):
Δu = A \phi_0(tA) u_0 Δt + (\phi_0(tA) Δu_0 + \sum_{i=1}^\infty t^i \phi_i(tA) ΔA A^{i-1} u_0),
the part in parentheses being the solution to the ODE
Δu′ = A Δu + ΔA u.
Perhaps there's some way to simplify that hideous sum term. Still need to work out the corresponding reverse mode.
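This pushforward can be verified numerically. A SciPy sketch (illustrative, not the package's API): the series sum equals the Fréchet derivative of expm at tA in direction t·ΔA, so scipy.linalg.expm_frechet stands in for it.

```python
import numpy as np
from scipy.linalg import expm, expm_frechet

rng = np.random.default_rng(2)
n = 4
t = 0.5
A = rng.standard_normal((n, n))
u0 = rng.standard_normal(n)
dt_, dA, du0 = 0.3, rng.standard_normal((n, n)), rng.standard_normal(n)

E = expm(t * A)
# expm_frechet(tA, t·ΔA) equals the φ-series sum term by term;
# the Δt term is A exp(tA) u0 Δt (φ_0 = exp).
_, F = expm_frechet(t * A, t * dA)
pushforward = dt_ * (A @ E @ u0) + E @ du0 + F @ u0

# Finite-difference check of the full directional derivative
h = 1e-6
f = lambda h: expm((t + h * dt_) * (A + h * dA)) @ (u0 + h * du0)
fd = (f(h) - f(-h)) / (2 * h)
print(np.allclose(pushforward, fd, atol=1e-5))  # True
```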

@sethaxen
Author

Following up on @ChrisRackauckas's point, we can indeed compute the adjoint of A by solving an ODE in reverse. A working prototype here:
https://gist.github.com/sethaxen/4071b401b9b4ff4f5421136cec2fa7da/dd914b79d465d8653b1674cbc466f5a29d95fbae#file-expv_chainrules-jl-L64-L77

I haven't worked out how to solve this ODE using just the functions in this package; currently I require OrdinaryDiffEq. This does what I need right now, so I'll put #51 on hold until I work out something efficient I can do using just this package.
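For reference, the same reverse-ODE idea can be sketched with SciPy instead of OrdinaryDiffEq (illustrative only, not the gist's or the package's implementation): integrate u′ = Au, μ′ = −A′μ, and dĀ/ds = −μ uᵀ backward from s = t to 0, starting from u(t) = w, μ(t) = Δw, Ā(t) = 0; then Ā(0) = ∫₀ᵗ μ(s) u(s)ᵀ ds is the adjoint of A.

```python
import numpy as np
from scipy.linalg import expm
from scipy.integrate import solve_ivp

rng = np.random.default_rng(3)
n = 4
t = 0.8
A = rng.standard_normal((n, n))
u0 = rng.standard_normal(n)
dw = rng.standard_normal(n)          # incoming adjoint Δw

w = expm(t * A) @ u0

# Backward augmented ODE: state = [u(s); μ(s); vec(Ā(s))] with
#   du/ds = A u,  dμ/ds = -A' μ,  dĀ/ds = -μ uᵀ,
# integrated from s = t down to s = 0.
def rhs(s, y):
    u, mu = y[:n], y[n:2 * n]
    return np.concatenate([A @ u, -A.T @ mu, -np.outer(mu, u).ravel()])

y0 = np.concatenate([w, dw, np.zeros(n * n)])
sol = solve_ivp(rhs, (t, 0.0), y0, rtol=1e-10, atol=1e-12)
dA = sol.y[2 * n:, -1].reshape(n, n)

# Finite-difference check of ∂⟨Δw, exp(tA) u0⟩ / ∂A[i, j]
h = 1e-6
fd = np.empty((n, n))
for i in range(n):
    for j in range(n):
        P = np.zeros((n, n)); P[i, j] = h
        fd[i, j] = (dw @ (expm(t * (A + P)) @ u0)
                    - dw @ (expm(t * (A - P)) @ u0)) / (2 * h)
print(np.allclose(dA, fd, atol=1e-5))  # True
```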

@sethaxen
Author

Another way to compute the adjoint of A comes from https://doi.org/10.1109/TAC.1978.1101743.
Let w = expv(t, A, v), Δw be the adjoint of w, and ∂v = expv(t, A', Δw) be the pulled back adjoint of v.
The adjoint of A is then the integral ∫_0^t exp(s A') Δw w' exp(-s A') ds.
Define the block-triangular matrix D = [-A' ∂v*w'; zero(A) -A']. Then the upper right block of exp(t * D) is the adjoint of A. This is fine for small dense A but is otherwise very inefficient, so this doesn't seem useful.
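Here's a numerical check of the block-triangular construction, sketched with SciPy (not the package's API); entrywise Fréchet derivatives of expm supply the reference value ∂A[i,j] = ⟨Δw, (d exp(tA)/dA[i,j]) v⟩:

```python
import numpy as np
from scipy.linalg import expm, expm_frechet

rng = np.random.default_rng(4)
n = 4
t = 0.6
A = rng.standard_normal((n, n))
v = rng.standard_normal(n)
dw = rng.standard_normal(n)              # adjoint Δw of w = expv(t, A, v)

w = expm(t * A) @ v
dv = expm(t * A.T) @ dw                  # pulled-back adjoint ∂v

# Block-triangular trick: the upper-right block of exp(tD) with
# D = [-A'  ∂v w'; 0  -A'] is the adjoint of A.
D = np.block([[-A.T, np.outer(dv, w)], [np.zeros((n, n)), -A.T]])
dA_block = expm(t * D)[:n, n:]

# Reference value via the Fréchet derivative of expm, entry by entry.
dA_ref = np.empty((n, n))
for i in range(n):
    for j in range(n):
        E = np.zeros((n, n)); E[i, j] = 1.0
        dA_ref[i, j] = dw @ (expm_frechet(t * A, t * E, compute_expm=False) @ v)
print(np.allclose(dA_block, dA_ref))  # True
```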

@sethaxen
Author

Here's where I landed on this. The adjoint for A will be computed by hand-deriving the pullback through exp and arnoldi/lanczos. The former will be added to ChainRules (JuliaDiff/ChainRules.jl#331). I locally have an implementation of the latter that requires no checkpointing.

For matrix n × n A, the final step of the pullback for arnoldi is the product of an n × m matrix and the adjoint of another n × m matrix, where m is the dimension of the Krylov subspace. For dense A, this is just a matmul, but for huge sparse A, we would need to know its sparsity pattern to avoid creating a huge dense matrix and instead only compute certain dot products of columns.

We need a function like outer_sparse!(∂A, x::AbstractVecOrMat, y::AbstractVecOrMat), where ∂A is a differential type for A (either Composite{typeof(A)} or an AbstractMatrix), that does this. We can implement such a function for all AbstractMatrix types in base Julia and define the rrule only for those types, wrapping an expv_rev that has no type constraints. Then the implementer of a custom operator can overload outer_sparse! for their operator and define an rrule wrapping expv_rev. Unfortunately, this would require the array package to depend on ExponentialUtilities, or a user to commit type piracy.
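To illustrate the contract, here's a hypothetical outer_sparse sketched with SciPy sparse matrices (the name and behavior come from this comment only; no such function exists in any package). It evaluates X Y' only at the stored positions of A, avoiding the dense n × n product:

```python
import numpy as np
import scipy.sparse as sp

def outer_sparse(A, X, Y):
    """Compute (X @ Y.conj().T) only at the stored positions of sparse A.

    X and Y are n × m (m = Krylov dimension); the full product would be a
    dense n × n matrix, but the pullback only needs the entries in A's
    sparsity pattern, each of which is a dot product of two length-m rows.
    """
    A = A.tocoo()
    data = np.einsum("ik,ik->i", X[A.row], Y[A.col].conj())
    return sp.coo_matrix((data, (A.row, A.col)), shape=A.shape)

rng = np.random.default_rng(5)
n, m = 50, 6
A = sp.random(n, n, density=0.05, random_state=5).tocoo()
X, Y = rng.standard_normal((n, m)), rng.standard_normal((n, m))

dA = outer_sparse(A, X, Y)
dense = X @ Y.T
# Agrees with the dense product at the stored positions:
print(np.allclose(dA.toarray()[A.row, A.col], dense[A.row, A.col]))  # True
```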

@stevengj

stevengj commented Jan 11, 2024

(Note that this block-triangular rule is a special case of an algorithm to differentiate matrix functions by Mathias in 1996, as discussed in JuliaDiff/ChainRules.jl#764)
