Add rules for the matrix exponential #331

Closed
sethaxen opened this issue Dec 23, 2020 · 6 comments · Fixed by #351

@sethaxen
Member

sethaxen commented Dec 23, 2020

TL;DR: a better approach to rules for the matrix exponential of dense matrices.

We should add rules for exp(::StridedMatrix), which would supersede Zygote's.

Ideally exp would need no rule at all, i.e. we would just AD through exp, which uses the scaling and squaring algorithm. However, that implementation uses mutation, which Zygote doesn't support. It is also limited to BlasFloat, so neither ForwardDiff nor, I think, any of the operator-overloading ADs can handle it. All that to say, we should have the rules.

Zygote currently computes the pullback from the power series using an eigendecomposition. In general, the eigendecomposition is not an accurate way to compute the exponential (https://epubs.siam.org/doi/10.1137/S00361445024180), though it is fine for Hermitian matrices, hence the exp(::Hermitian) overload in LinearAlgebra. Because Zygote's adjoint uses exp for the primal, the potential inaccuracy is confined to the pullback. However, the pullback does not have the same time complexity as the primal, and it is quite wasteful.

EDIT: everything said below is still valid, but it is a general property of matrix functions defined by power series with real coefficients that the pullback is the pushforward pre- and post-composed with the adjoint or, equivalently, the pushforward evaluated at the adjoint of the primal input. That is, if Y = f(A), then we have the equality

(f^*)_{Y} (ΔY) = (f_*)_{A'} (ΔY)

This applies to exp, log, and all trigonometric and hyperbolic functions.
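The identity can be sanity-checked numerically with an adjoint (dot-product) test: if the pullback at A is the pushforward at A', then for random ΔA and ΔY we should have real(dot(ΔY, (f_*)_A(ΔA))) ≈ real(dot((f_*)_{A'}(ΔY), ΔA)). The sketch below is illustrative only. It assumes the real inner product real(dot(X, Z)), approximates both pushforwards with a hypothetical central-difference helper `fd_pushforward` (instead of FiniteDifferences.jl), and checks only exp, sin, and cosh.

```julia
using LinearAlgebra, Random

# Hypothetical helper: central-difference approximation of the pushforward
# (f_*)_A(ΔA) = d/dt f(A + t ΔA) at t = 0.
fd_pushforward(f, A, ΔA; h=1e-6) = (f(A + h * ΔA) - f(A - h * ΔA)) / (2h)

Random.seed!(1)
n = 6
A = randn(ComplexF64, n, n) / 2
ΔA = randn(ComplexF64, n, n)
ΔY = randn(ComplexF64, n, n)

# For each power series with real coefficients, compare
# <ΔY, (f_*)_A(ΔA)> against <(f_*)_{A'}(ΔY), ΔA>.
results = map((exp, sin, cosh)) do f
    lhs = real(dot(ΔY, fd_pushforward(f, A, ΔA)))
    rhs = real(dot(fd_pushforward(f, A', ΔY), ΔA))
    (f, lhs, rhs)
end

for (f, lhs, rhs) in results
    println(f, ": ", lhs, " vs ", rhs)
end
```

The two columns should agree up to finite-difference error.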

Y = exp(A) appears in the solution to a linear ODE (X(t) = exp(tA) X(0) solves dX/dt = A X). We can augment the ODE to get a new one whose solution also uses the matrix exponential and which gives us the pushforward (discussed in section 7 of https://www2.humusoft.cz/www/papers/tcp08/017_brancik.pdf, though the result is older and can be worked out from https://ieeexplore.ieee.org/document/1101743 or some clever algebra).
In short, given B = [A ΔA; zero(A) A], we have exp(B) = [Y ∂Y; zero(Y) Y], where ∂Y is the pushforward of ΔA.
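A minimal sketch of this pushforward construction, checked against a central finite difference. The helper name `exp_pushforward_block`, the step `h`, and the matrix size are illustrative assumptions, not part of any package.

```julia
using LinearAlgebra, Random

# Hypothetical helper: pushforward of exp at A in direction ΔA, read off the
# upper-right block of exp([A ΔA; 0 A]).
function exp_pushforward_block(A::AbstractMatrix, ΔA::AbstractMatrix)
    n = size(A, 1)
    E = exp([A ΔA; zero(A) A])
    Y = E[1:n, 1:n]          # top-left block is exp(A)
    ∂Y = E[1:n, n+1:2n]      # top-right block is the pushforward
    return Y, ∂Y
end

Random.seed!(0)
A, ΔA = randn(10, 10), randn(10, 10)
Y, ∂Y = exp_pushforward_block(A, ΔA)

# Central finite difference of t -> exp(A + t ΔA) at t = 0, for comparison.
h = 1e-6
∂Y_fd = (exp(A + h * ΔA) - exp(A - h * ΔA)) / (2h)

println("primal error:      ", norm(Y - exp(A)) / norm(exp(A)))
println("pushforward error: ", norm(∂Y - ∂Y_fd) / norm(∂Y))
```

The whole computation is a single exp of a 2n×2n matrix, which is what makes the approach simple but expensive.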

I didn't find a reference for this, but we can do the same thing for the pullback by constructing and solving the adjoint ODE:
given B = [A ΔY'; zero(A) A], we have exp(B) = [Y ∂A'; zero(Y) Y].
That is, the pullback exp^* is related to the pushforward exp_{*} by exp^* = adjoint ∘ exp_{*} ∘ adjoint.
This is easy to verify:

```julia
julia> using LinearAlgebra, FiniteDifferences

julia> A, Δ = randn(ComplexF64, 30, 30), randn(ComplexF64, 30, 30);

julia> only(j′vp(central_fdm(5, 1), exp, Δ, A)) ≈ jvp(central_fdm(5, 1), exp, (A, Δ'))'
true
```
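The same relationship can also be checked without finite differences: compute the pushforward and the pullback with block matrices and confirm that they satisfy the adjoint (dot-product) identity ⟨ΔY, ∂Y⟩ = ⟨∂A, ΔA⟩ under the real inner product real(dot(X, Z)). This is a sketch; both helper functions are hypothetical and are not ChainRules APIs.

```julia
using LinearAlgebra, Random

# Hypothetical helpers for illustration (not ChainRules APIs).
# Pushforward: the upper-right block of exp([A ΔA; 0 A]).
function exp_pushforward_block(A, ΔA)
    n = size(A, 1)
    return exp([A ΔA; zero(A) A])[1:n, n+1:2n]
end

# Pullback: the upper-right block of exp([A ΔY'; 0 A]) is ∂A', so return its adjoint.
function exp_pullback_block(A, ΔY)
    n = size(A, 1)
    return exp([A ΔY'; zero(A) A])[1:n, n+1:2n]'
end

Random.seed!(0)
n = 8
A, ΔA, ΔY = randn(ComplexF64, n, n), randn(ComplexF64, n, n), randn(ComplexF64, n, n)

∂Y = exp_pushforward_block(A, ΔA)
∂A = exp_pullback_block(A, ΔY)

# Adjoint (dot-product) test under the real inner product <X, Z> = real(dot(X, Z)).
lhs = real(dot(ΔY, ∂Y))
rhs = real(dot(∂A, ΔA))
println(lhs, " ≈ ", rhs)
```

Using complex matrices makes the test sensitive to whether the adjoints are placed correctly.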

The problem with the augmented matrix approach is that it costs about 8x as much as the primal (the work is dominated by O(n³) matrix products, and the augmented matrix is 2n×2n, so each product costs 2³ = 8 times as much), when we should be able to get under 5x. For small matrices (<100x100) this is faster than the eigendecomposition approach, and it should be more accurate, but for large dense matrices the eigendecomposition approach is faster.

But the relationship between the pushforward and pullback motivates a solution: explicitly implement the pushforward of the scaling and squaring approach used by LinearAlgebra.exp!. Not only would we get the pushforward with the same time complexity as the primal, but we could then compute the pullback with the same time complexity as well, without needing to checkpoint any intermediate matrices, and with mutation allowed.
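A rough sketch of what a fused pushforward of scaling and squaring could look like. It is not a port of LinearAlgebra.exp!: as a deliberate simplification, it uses a degree-18 Taylor polynomial in place of the Padé approximant that exp! uses, and it picks the scaling parameter crudely. The point is the structure: each step of the primal is paired with its product-rule derivative, and the pullback reuses the same routine at A', via the identity (f^*)_{Y}(ΔY) = (f_*)_{A'}(ΔY). The function name `exp_and_pushforward_ss` is hypothetical.

```julia
using LinearAlgebra, Random

# Sketch: exp(A) and its pushforward in direction ΔA, computed together by
# scaling and squaring.
# ASSUMPTIONS (not what LinearAlgebra.exp! does): a degree-m Taylor polynomial
# replaces exp!'s Padé approximant, and s is chosen crudely so that
# opnorm(A / 2^s, 1) <= 1/2.
function exp_and_pushforward_ss(A::AbstractMatrix, ΔA::AbstractMatrix; m::Int=18)
    n = size(A, 1)
    nrm = opnorm(A, 1)
    s = nrm > 0.5 ? ceil(Int, log2(nrm / 0.5)) : 0
    # Scaling is linear, so the tangent is scaled by the same factor.
    As, ΔAs = A / 2^s, ΔA / 2^s

    # Horner evaluation of the Taylor polynomial, X <- I + (As * X) / k,
    # paired with its product-rule derivative (dX is updated before X).
    T = promote_type(eltype(As), eltype(ΔAs))
    X = Matrix{T}(I, n, n)
    dX = zeros(T, n, n)
    for k in m:-1:1
        dX = (ΔAs * X + As * dX) / k
        X = I + (As * X) / k
    end

    # Squaring phase: X <- X^2 and dX <- X dX + dX X (again dX first).
    for _ in 1:s
        dX = X * dX + dX * X
        X = X * X
    end
    return X, dX
end

Random.seed!(0)
n = 8
A, ΔA, ΔY = randn(ComplexF64, n, n), randn(ComplexF64, n, n), randn(ComplexF64, n, n)

# Pushforward at A.
Y, ∂Y = exp_and_pushforward_ss(A, ΔA)
# Pullback via (exp^*)_Y(ΔY) = (exp_*)_{A'}(ΔY): the same routine, run at A'.
_, ∂A = exp_and_pushforward_ss(A', ΔY)
```

The pullback here stores no intermediate matrices from the primal; it is simply another forward pass at A'.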

I'm planning to tackle this after some of my other open PRs are wrapped up, but I wanted to get it in writing while it was fresh on my mind.

@wangleiphy

Thanks for sharing your nice observation!
A question: how do you plan to "explicitly implement the pushforward of the scaling and squaring approach" ?

@sethaxen
Member Author

> Thanks for sharing your nice observation!

No problem! I have updated the issue with a more general statement.

> A question: how do you plan to "explicitly implement the pushforward of the scaling and squaring approach" ?

One of two ways:

  1. Implement the Fréchet derivative (i.e. pushforward) of the matrix exponential from http://eprints.ma.man.ac.uk/1218/1/covered/MIMS_ep2008_26.pdf, which fuses computation of the exponential and its pushforward. This algorithm is widely used, and the paper analyzes its properties extensively. I just need to check whether its computation of the exponential is equivalent, or in any way inferior, to what LinearAlgebra.exp! currently does.
  2. Hand-derive the pushforward of LinearAlgebra.exp!. At the simplest level, this just replaces each function call with its fused pushforward, as any forward-mode AD package would do, and then analyzes the result to remove terms that cancel or to otherwise reduce computation.

I'll probably do (1).
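For a sense of what option (2) involves, here is a hypothetical sketch applying it to a single building block, the degree-3 Padé approximant R = q(A) \ p(A) with p(x) = 1 + x/2 + x²/10 + x³/120 and q(x) = p(-x). It is not the full algorithm (the scaling and squaring and the higher-degree approximants are omitted), and the function name is made up. Each matrix operation is replaced by its fused pushforward, using the product rule and dR = Q \ (dP - dQ R) for R = Q \ P.

```julia
using LinearAlgebra, Random

# Hypothetical sketch of option (2) on one building block: the degree-3 Padé
# approximant R = q(A) \ p(A) to exp(A), with
#   p(x) = 1 + x/2 + x^2/10 + x^3/120   and   q(x) = p(-x).
# Each operation is replaced by its fused pushforward:
#   d(X * Z) = dX * Z + X * dZ,   and for R = Q \ P,   dR = Q \ (dP - dQ * R).
function pade3_and_pushforward(A::AbstractMatrix, ΔA::AbstractMatrix)
    A2, dA2 = A * A, ΔA * A + A * ΔA
    A3, dA3 = A2 * A, dA2 * A + A2 * ΔA
    # Even part U and odd part V, so that p(A) = U + V and q(A) = U - V.
    U, dU = I + A2 / 10, dA2 / 10
    V, dV = A / 2 + A3 / 120, ΔA / 2 + dA3 / 120
    Q = U - V
    R = Q \ (U + V)
    dR = Q \ ((dU + dV) - (dU - dV) * R)
    return R, dR
end

Random.seed!(0)
A = 0.01 * randn(6, 6)   # small norm, so the degree-3 approximant is close to exp(A)
ΔA = randn(6, 6)
R, dR = pade3_and_pushforward(A, ΔA)

# Central finite difference of the primal R(A), for comparison.
h = 1e-6
dR_fd = (first(pade3_and_pushforward(A + h * ΔA, ΔA)) -
         first(pade3_and_pushforward(A - h * ΔA, ΔA))) / (2h)
println("pushforward error: ", norm(dR - dR_fd) / norm(dR))
```

The linear solve is reused: the pushforward needs only one extra solve with the same Q, whose factorization could be shared in a real implementation.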

@wangleiphy
Copy link

wangleiphy commented Dec 24, 2020

Thanks.

If I understand correctly, this PR in JAX switched from (1) to (2) for the JVP of expm: jax-ml/jax#4314

and this one, jax-ml/jax#4331, implements expm_frechet using jvp(expm), that is, implementing (1) using (2).

@sethaxen
Member Author

Yeah, the way JAX does it is definitely better. Here we're limited, in a sense, by the fact that this package provides rules for all ChainRules-compatible ADs. We currently don't have a way to embed an automatically differentiated function within a custom rule, though that is planned (JuliaDiff/ChainRulesCore.jl#68). But even that wouldn't help here, where we want to compose a pushforward and a pullback, which will in general be provided by different AD packages. Part of the cost of being general.

@wangleiphy

Thanks!

I believe this is also related: https://github.com/Lezcano/expRNN/blob/830ec836521d0c295436dcafc3f0b3deea36c83c/trivializations.py#L19

Though I am not sure why there is only one transpose there...

@oxinabox
Member

oxinabox commented Dec 24, 2020

> (JuliaDiff/ChainRulesCore.jl#68). But even that wouldn't help here, where we want to compose a pushforward and a pullback, which will in general be provided by different AD packages.

JuliaDiff/ChainRulesCore.jl#68 should be able to do that,
though it will require the reverse-mode AD to specify a forward-mode AD.
Nabla does that already: it has a little forward-mode AD using DualNumbers.jl (changed to ForwardDiff.jl in the ChainRules PR) that it uses internally in its rules for functions like `map`.
I have some designs for how this will be done so that the rules are simply not defined if you don't have all the required AD systems.

3 participants