Add rules for dense matrix exponential #351

Merged
merged 34 commits into from
Jan 20, 2021
Changes from 12 commits
Commits
6e8358c
Add matfun.jl file
sethaxen Jan 18, 2021
f987f35
Add matfun docstrings
sethaxen Jan 18, 2021
e6b92c9
Add exp matrix function
sethaxen Jan 18, 2021
7624ee5
At least store one intermediate
sethaxen Jan 18, 2021
a7792a5
Test exp!
sethaxen Jan 18, 2021
6d6b4cb
Make pullback type-inferrable
sethaxen Jan 18, 2021
3645d75
Add clearer test label
sethaxen Jan 18, 2021
937f2ac
Create as hermitian
sethaxen Jan 18, 2021
b48204c
Test rrule
sethaxen Jan 18, 2021
2c19bba
Add comment about relationship between pushforward and pullback
sethaxen Jan 18, 2021
58f6005
Add header
sethaxen Jan 18, 2021
6ee1759
Add reference to Frechet deriv paper
sethaxen Jan 18, 2021
b1a2980
Run JuliaFormatter
sethaxen Jan 18, 2021
e860b3e
Reduce comment spacing from code
sethaxen Jan 18, 2021
8f665ac
Update src/rulesets/LinearAlgebra/matfun.jl
sethaxen Jan 18, 2021
9e565ae
Correctly handle balancing
sethaxen Jan 18, 2021
71134fd
Test imbalanced matrix A
sethaxen Jan 18, 2021
bd48565
Increment version number
sethaxen Jan 18, 2021
062b11d
Merge branch 'exp2' of https://github.com/sethaxen/ChainRules.jl into…
sethaxen Jan 18, 2021
dc1b1ab
Apply suggestions from code review
sethaxen Jan 19, 2021
57aea17
Change signature of _matfun_frechet
sethaxen Jan 20, 2021
e2e6605
Give math for Frechet derivative
sethaxen Jan 20, 2021
976af09
Change Frechet notation
sethaxen Jan 20, 2021
d7d20ba
Add _matfun_frechet_adjoint
sethaxen Jan 20, 2021
b0ae61c
Simplify hermitian code
sethaxen Jan 20, 2021
62b963b
Correct comment
sethaxen Jan 20, 2021
87e4c53
Remove comments
sethaxen Jan 20, 2021
9bd06b1
Use abbreviated SHA
sethaxen Jan 20, 2021
2ed06e6
Link
sethaxen Jan 20, 2021
156e6f5
Update comment
sethaxen Jan 20, 2021
9a63d13
Move comment up
sethaxen Jan 20, 2021
5ba193d
Move comment further up
sethaxen Jan 20, 2021
49df929
Update docstrings
sethaxen Jan 20, 2021
8c27276
Push header to same level as rules
sethaxen Jan 20, 2021
1 change: 1 addition & 0 deletions src/ChainRules.jl
@@ -44,6 +44,7 @@ include("rulesets/LinearAlgebra/utils.jl")
 include("rulesets/LinearAlgebra/blas.jl")
 include("rulesets/LinearAlgebra/dense.jl")
 include("rulesets/LinearAlgebra/norm.jl")
+include("rulesets/LinearAlgebra/matfun.jl")
 include("rulesets/LinearAlgebra/structured.jl")
 include("rulesets/LinearAlgebra/symmetric.jl")
 include("rulesets/LinearAlgebra/factorization.jl")
243 changes: 243 additions & 0 deletions src/rulesets/LinearAlgebra/matfun.jl
@@ -0,0 +1,243 @@
# matrix functions of dense matrices

# NOTE: for matrix functions whose power series representation has real coefficients,
# the pullback and pushforward are related by an adjoint.
# Specifically, if the pushforward of f(A) is (f_*)_A(ΔA), then the pullback at Y=f(A) is
# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))'
Member Author

Well that is hideous, but notation is hard, and harder in unicode.

Member

Unicode's missing subscripts make it extra hard.

Idea that might not be worth doing:
What if we just made a section for this in the docs (maybe as internal notes or something),
wrote the LaTeX, and then linked to that?

But yeah, notation for pullbacks and pushforwards is hard.
It has to convey so much state.

Member

This bit makes sense:
(f^*)_Y(ΔY) = ((f_*)_A(ΔY'))'
So the pullback at A, i.e. the pullback from Y (though that's not well defined since not all functions are monotonic?),
is equal to the adjoint of the pushing forward at A, the adjoint of the output sensitivity.
The fact that that is also equal to (f_*)_{A'}(ΔY) is pretty magic.

Magical exponential symmetry? (I feel like I made the same surprised sounds for the same reason on your last PR.)

Member Author

> What if we just made a section for this in the docs (maybe as internal notes or something),
> wrote the LaTeX, and then linked to that?

Hm, that's an idea. I'll consider it, potentially for a future PR.

> This bit makes sense:
> (f^*)_Y(ΔY) = ((f_*)_A(ΔY'))'
> So the pullback at A, i.e. the pullback from Y...is equal to the adjoint of the pushing forward at A, the adjoint of the output sensitivity.

Ah yes, your description is correct (although it's the adjoint of the pushing forward of the adjoint). I just checked Lee, and this should be the right notation:

Suggested change
- # (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))'
+ # (f^*)_A(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))'

> (though that's not well defined since not all functions are monotonic?)

I'm not sure what you mean by this.

> Magical exponential symmetry? (I feel like I made the same surprised sounds for the same reason on your last PR)

It's still surprising to me, although this property is general for all of the matrix functions defined in LinearAlgebra, not just exp. It doesn't hold for all matrix functions though, just those whose convergent power series have real coefficients.
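
The relationship discussed in this thread can be checked numerically. The following is a minimal sketch, not the code from this PR: it assumes the standard block-matrix identity that the upper-right block of `exp([A E; 0 A])` is the Fréchet derivative of `exp` at `A` in the direction `E`, and the helper name `frechet_exp_block` is hypothetical.

```julia
using LinearAlgebra, Random

# Hypothetical helper: Fréchet derivative of exp at A in direction E, read off the
# upper-right block of exp([A E; 0 A]). This is not the algorithm used in the PR.
function frechet_exp_block(A::AbstractMatrix, E::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    M = [A E; zero(A) A]
    return exp(M)[1:n, (n + 1):(2n)]
end

Random.seed!(1)
n = 4
A = randn(ComplexF64, n, n)
ΔA = randn(ComplexF64, n, n)
ΔY = randn(ComplexF64, n, n)

pushforward = frechet_exp_block(A, ΔA)                        # (f_*)_A(ΔA)
pullback_direct = frechet_exp_block(Matrix(A'), ΔY)           # (f_*)_{A'}(ΔY)
pullback_reused = Matrix(frechet_exp_block(A, Matrix(ΔY'))')  # ((f_*)_A(ΔY'))'

# Defining property of a pullback: ⟨ΔY, pushforward(ΔA)⟩ == ⟨pullback(ΔY), ΔA⟩
lhs = dot(ΔY, pushforward)
rhs = dot(pullback_direct, ΔA)
```

If the identity holds, `lhs` and `rhs` agree up to rounding, and so do `pullback_direct` and `pullback_reused`. The second agreement is what lets the pullback reuse the pushforward code with adjoints applied on the way in and out.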

# So we reuse the code from the pushforward to implement the pullback.

"""
    _matfun(f, A) -> (Y, intermediates)

Compute the matrix function `Y=f(A)` for matrix `A`.
The function returns a tuple containing the result and a tuple of intermediates to be
reused by `_matfun_frechet` to compute the Fréchet derivative.
Note that any function `f` used with this **must** have a `frule` defined on it.
"""
_matfun

"""
    _matfun!(f, A) -> (Y, intermediates)

Similar to [`_matfun`](@ref), but where `A` may be overwritten.
"""
_matfun!

"""
    _matfun_frechet(f, A, Y, ΔA, intermediates)

Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative
of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`.
"""
_matfun_frechet

"""
    _matfun_frechet!(f, A, Y, ΔA, intermediates)

Similar to [`_matfun_frechet`](@ref), but where `ΔA` may be overwritten.
"""
_matfun_frechet!

#####
##### `exp`/`exp!`
#####

function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFloat})
    if ishermitian(A)
        hermX, ∂hermX = frule((Zero(), ΔA), exp, Hermitian(A))
        X = LinearAlgebra.copytri!(parent(hermX), 'U', true)
        if ∂hermX isa LinearAlgebra.RealHermSymComplexHerm
            ∂X = LinearAlgebra.copytri!(parent(∂hermX), 'U', true)
        else
            ∂X = ∂hermX
        end
    else
        X, intermediates = _matfun!(exp, A)
        ∂X = _matfun_frechet!(exp, A, X, ΔA, intermediates)
    end
    return X, ∂X
end

function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat})
    # TODO: try to make this more type-stable
    if ishermitian(A0)
        # call _matfun instead of the rrule to avoid hermitianizing ∂A in the pullback
        hermA = Hermitian(A0)
        hermX, hermX_intermediates = _matfun(exp, hermA)
        function exp_pullback_hermitian(ΔX)
            ∂hermA = _matfun_frechet(exp, hermA, hermX, ΔX, hermX_intermediates)
            ∂hermA isa LinearAlgebra.RealHermSymComplexHerm || return NO_FIELDS, ∂hermA
            return NO_FIELDS, parent(∂hermA)
        end
        return LinearAlgebra.copytri!(parent(hermX), 'U', true), exp_pullback_hermitian
    else
        A = copy(A0)
        X, intermediates = _matfun!(exp, A)
        function exp_pullback(ΔX)
            ΔX′ = copy(adjoint(ΔX))
            ∂A′ = _matfun_frechet!(exp, A, X, ΔX′, intermediates)
            ∂A = copy(adjoint(∂A′))
            return NO_FIELDS, ∂A
        end
        return X, exp_pullback
    end
end

## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
## Adapted from LinearAlgebra.exp! with return of intermediates
function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat
    n = LinearAlgebra.checksquare(A)
    ilo, ihi, scale = LAPACK.gebal!('B', A)    # modifies A
    nA = opnorm(A, 1)
    Inn = Matrix{T}(I, n, n)
    ## For sufficiently small nA, use lower order Padé-Approximations
    if (nA <= 2.1)
        if nA > 0.95
            C = T[17643225600.,8821612800.,2075673600.,302702400.,
                  30270240., 2162160., 110880., 3960.,
                  90., 1.]
        elseif nA > 0.25
            C = T[17297280.,8648640.,1995840.,277200.,
                  25200., 1512., 56., 1.]
        elseif nA > 0.015
            C = T[30240.,15120.,3360.,
                  420., 30., 1.]
        else
            C = T[120.,60.,12.,1.]
        end
        si = 0
    else
        C = T[64764752532480000.,32382376266240000.,7771770303897600.,
              1187353796428800., 129060195264000., 10559470521600.,
              670442572800., 33522128640., 1323241920.,
              40840800., 960960., 16380.,
              182., 1.]
        s = log2(nA/5.4)               # power of 2 later reversed by squaring
        si = ceil(Int,s)
    end

    if si > 0
        A ./= convert(T,2^si)
    end

    A2 = A * A
    P = copy(Inn)
    W = C[2] * P
    V = C[1] * P
    Apows = typeof(P)[]
    for k in 1:(div(size(C, 1), 2) - 1)
        k2 = 2 * k
        P *= A2
        push!(Apows, P)
        W += C[k2 + 2] * P
        V += C[k2 + 1] * P
    end
    U = A * W
    X = V + U
    F = lu!(V-U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization
    ldiv!(F, X)
    Xpows = typeof(X)[X]
    if si > 0  # squaring to reverse dividing by power of 2
        for t=1:si
            X *= X
            push!(Xpows, X)
        end
    end

    # Undo the balancing
    for j = ilo:ihi
        scj = scale[j]
        for i = 1:n
            X[j,i] *= scj
        end
        for i = 1:n
            X[i,j] /= scj
        end
    end

    if ilo > 1       # apply lower permutations in reverse order
        for j in (ilo-1):-1:1; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end
    end
    if ihi < n       # apply upper permutations in forward order
        for j in (ihi+1):n; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end
    end
    return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
end
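
To make the Padé and squaring steps above easier to follow, here is a stripped-down sketch. It is a simplification, not the PR's method: it uses only the lowest-order coefficient set `[120, 60, 12, 1]` from the diff, skips balancing, and always scales down to the 0.015 threshold instead of switching to higher-order coefficient sets. The function names `exp_pade3` and `exp_scaled_pade3` are hypothetical.

```julia
using LinearAlgebra, Random

# [3/3] Padé approximant of exp, built the same way as U, V, and X in the diff.
function exp_pade3(A::AbstractMatrix)
    C = (120.0, 60.0, 12.0, 1.0)
    A2 = A * A
    W = C[4] * A2 + C[2] * I   # even polynomial that multiplies A to form U
    V = C[3] * A2 + C[1] * I   # even part of the numerator
    U = A * W                  # odd part of the numerator
    return (V - U) \ (V + U)   # r(A) = q(A)^{-1} p(A), with q(A) = p(-A)
end

# Scaling and squaring around the low-order approximant (sketch-only strategy).
function exp_scaled_pade3(A::AbstractMatrix)
    nA = opnorm(A, 1)
    s = nA > 0.015 ? ceil(Int, log2(nA / 0.015)) : 0
    X = exp_pade3(A / 2^s)     # approximate exp(A / 2^s)
    for _ in 1:s               # square s times to undo the scaling
        X = X * X
    end
    return X
end

Random.seed!(2)
A_small = randn(5, 5)
A_small .*= 0.01 / opnorm(A_small, 1)
err_small = opnorm(exp_pade3(A_small) - exp(A_small), 1)

A_big = randn(5, 5)
A_big .*= 3.0 / opnorm(A_big, 1)
err_big = opnorm(exp_scaled_pade3(A_big) - exp(A_big), 1) / opnorm(exp(A_big), 1)
```

Both errors should be near rounding level. The version in the diff instead picks a higher-order coefficient set as the norm grows, which keeps the number of squarings small.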

# Application of the chain rule to exp!, also Algorithm 6.4 from
# Al-Mohy, Awad H. and Higham, Nicholas J. (2009).
# "Computing the Fréchet Derivative of the Matrix Exponential, with an application to
# Condition Number Estimation", SIAM J. Matrix Anal. Appl. 30 (4). pp. 1639-1657.
# http://eprints.maths.manchester.ac.uk/id/eprint/1218
function _matfun_frechet!(
    ::typeof(exp),
    A::StridedMatrix{T},
    X,
    ΔA,
    (ilo, ihi, scale, C, si, Apows, W, F, Xpows),
) where {T<:BlasFloat}
    n = LinearAlgebra.checksquare(A)
    for j = ilo:ihi
        scj = scale[j]
        for i = 1:n
            ΔA[j,i] /= scj
        end
        for i = 1:n
            ΔA[i,j] *= scj
        end
    end

    if si > 0
        ΔA ./= convert(T, 2^si)
    end

    ∂A2 = mul!(A * ΔA, ΔA, A, true, true)
    A2 = first(Apows)
    # we will repeatedly overwrite ∂temp and ∂P below
    ∂temp = Matrix{eltype(∂A2)}(undef, n, n)
    ∂P = copy(∂A2)
    ∂W = C[4] * ∂P
    ∂V = C[3] * ∂P
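    # Apows[k] holds A^(2k) and every even power contributes to W and V, so k must run up
    # to length(Apows); stopping one short would drop the highest-order term
    for k in 2:length(Apows)
        k2 = 2 * k
        P = Apows[k - 1]
        ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P
        axpy!(C[k2 + 2], ∂P, ∂W)
        axpy!(C[k2 + 1], ∂P, ∂V)
    end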
    ∂U, ∂temp = mul!(mul!(∂temp, A, ∂W), ΔA, W, true, true), ∂W
    ∂temp .= ∂U .- ∂V
    ∂X = add!!(∂U, ∂V)
    mul!(∂X, ∂temp, first(Xpows), true, true)
    ldiv!(F, ∂X)

    if si > 0
        for t = 1:(length(Xpows)-1)
            X = Xpows[t]
            ∂X, ∂temp = mul!(mul!(∂temp, X, ∂X), ∂X, X, true, true), ∂X
        end
    end

    for j = ilo:ihi
        scj = scale[j]
        for i = 1:n
            ∂X[j,i] *= scj
        end
        for i = 1:n
            ∂X[i,j] /= scj
        end
    end

    if ilo > 1  # apply lower permutations in reverse order
        for j in (ilo-1):-1:1
            LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X)
        end
    end
    if ihi < n  # apply upper permutations in forward order
        for j in (ihi+1):n
            LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X)
        end
    end
    return ∂X
end
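
The squaring loop in the function above applies the product rule to `X = X * X` once per squaring. A minimal sketch of a single such step follows. As a reference it assumes the block-matrix identity that the upper-right block of `exp([A E; 0 A])` is the Fréchet derivative of `exp`; the helper `frechet_exp_block` is hypothetical and not part of the PR.

```julia
using LinearAlgebra, Random

# Hypothetical reference: Fréchet derivative of exp at A in direction E via the
# upper-right block of exp([A E; 0 A]).
function frechet_exp_block(A::AbstractMatrix, E::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    M = [A E; zero(A) A]
    return exp(M)[1:n, (n + 1):(2n)]
end

Random.seed!(3)
n = 4
A = randn(n, n)
E = randn(n, n)

# One level of scaling: exp(A) = exp(A/2)^2
X_half = exp(A / 2)
∂X_half = frechet_exp_block(A / 2, E / 2)   # derivative of the scaled problem

# Product rule for X = X_half * X_half, as in the `for t = ...` loop above
∂X = X_half * ∂X_half + ∂X_half * X_half

∂X_ref = frechet_exp_block(A, E)
```

Here `∂X` and `∂X_ref` should agree up to rounding. Repeating the product-rule step `si` times, each time with the stored power from `Xpows`, undoes the division of `ΔA` by `2^si`.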
6 changes: 0 additions & 6 deletions src/rulesets/LinearAlgebra/symmetric.jl
@@ -392,12 +392,6 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
 end

 # Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations
-"""
-    _matfun_frechet(f, A::RealHermSymComplexHerm, Y, ΔA, intermediates)
-
-Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative
-of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`.
-"""
 function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ))
     # We will overwrite tmp matrix several times to hold different values
     tmp = mul!(similar(U, Base.promote_eltype(U, ΔA)), ΔA, U)
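
The formula in the comment above, `∂Y = U * (P .* (U' * ΔA * U)) * U'`, can be illustrated for `f = exp` with a short allocation-heavy sketch. It assumes `P` is the matrix of divided differences of `exp` over the eigenvalues of `A`. The function names and the tolerance for treating eigenvalues as repeated are hypothetical, and the block-matrix reference is a standard identity, not code from this PR.

```julia
using LinearAlgebra, Random

# Divided difference of exp; on (near-)repeated eigenvalues fall back to exp'(λ) = exp(λ).
function exp_divided_difference(λi::Real, λj::Real, tol::Real)
    return abs(λi - λj) > tol ? (exp(λi) - exp(λj)) / (λi - λj) : exp(λi)
end

# Sketch of ∂Y = U * (P .* (U' * ΔA * U)) * U' for Y = exp(A) with Hermitian A.
function frechet_exp_hermitian(A::Hermitian, ΔA::AbstractMatrix)
    λ, U = eigen(A)
    tol = sqrt(eps(eltype(λ)))   # hypothetical tolerance for this sketch
    P = [exp_divided_difference(λi, λj, tol) for λi in λ, λj in λ]
    return U * (P .* (U' * ΔA * U)) * U'
end

# Hypothetical reference: upper-right block of exp([A E; 0 A]).
function frechet_exp_block(A::AbstractMatrix, E::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    M = [A E; zero(A) A]
    return exp(M)[1:n, (n + 1):(2n)]
end

Random.seed!(4)
n = 4
H = randn(ComplexF64, n, n)
A = Hermitian((H + H') / 2)
ΔA = randn(ComplexF64, n, n)

∂Y = frechet_exp_hermitian(A, ΔA)
∂Y_ref = frechet_exp_block(Matrix(A), ΔA)
```

The two results should match up to rounding. The version in `symmetric.jl` computes the same quantity while reusing a temporary matrix to cut allocations.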
36 changes: 36 additions & 0 deletions test/rulesets/LinearAlgebra/matfun.jl
@@ -0,0 +1,36 @@
@testset "matrix functions" begin
    @testset "LinearAlgebra.exp!(A::Matrix) frule" begin
        n = 10
        @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
            # use the element type T under test rather than always ComplexF64
            A, ΔA = randn(T, n, n), randn(T, n, n)
            # choose normalization to hit specific branch
            A *= nrm / opnorm(A, 1)
            frule_test(LinearAlgebra.exp!, (A, ΔA))
        end
        @testset "hermitian A" begin
            A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n)
            frule_test(LinearAlgebra.exp!, (A, Matrix(Hermitian(ΔA))))
            frule_test(LinearAlgebra.exp!, (A, ΔA))
        end
    end

    @testset "exp(A::Matrix) rrule" begin
        n = 10
        @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
            # use the element type T under test rather than always ComplexF64
            A, ΔA = randn(T, n, n), randn(T, n, n)
            ΔY = randn(T, n, n)
            # choose normalization to hit specific branch
            A *= nrm / opnorm(A, 1)
            # rrule is not inferrable, but pullback should be
            rrule_test(exp, ΔY, (A, ΔA); check_inferred = false)
            Y, back = rrule(exp, A)
            @inferred back(ΔY)
        end
        @testset "hermitian A" begin
            A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n)
            ΔY = randn(ComplexF64, n, n)
            rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred = false)
            rrule_test(exp, ΔY, (A, ΔA); check_inferred = false)
        end
    end
end
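
The norms used in these tests appear to be chosen so that each coefficient set in `_matfun!` is exercised. The sketch below mirrors the thresholds from the diff to show which branch each norm would select. The helper `pade_branch` is hypothetical, and the sketch ignores balancing, which can change the norm that `_matfun!` actually sees.

```julia
using LinearAlgebra, Random

# Hypothetical helper mirroring the thresholds in `_matfun!(::typeof(exp), A)`:
# returns the number of Padé coefficients and the number of squarings.
function pade_branch(nA::Real)
    nA > 2.1 && return (ncoeffs = 14, si = max(ceil(Int, log2(nA / 5.4)), 0))
    nA > 0.95 && return (ncoeffs = 10, si = 0)
    nA > 0.25 && return (ncoeffs = 8, si = 0)
    nA > 0.015 && return (ncoeffs = 6, si = 0)
    return (ncoeffs = 4, si = 0)
end

Random.seed!(5)
n = 10
norms = (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
branches = map(norms) do nrm
    A = randn(ComplexF64, n, n)
    A *= nrm / opnorm(A, 1)        # same normalization as in the tests
    pade_branch(opnorm(A, 1))
end

for (nrm, b) in zip(norms, branches)
    println("opnorm = ", nrm, ": ", b.ncoeffs, " coefficients, ", b.si, " squarings")
end
```

Under these assumptions, the first four norms each select a different low-order coefficient set, while 3.0, 6.0, and 12.0 all use the 14-coefficient set with zero, one, and two squarings respectively.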
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -43,6 +43,7 @@ println("Testing ChainRules.jl")
 @testset "LinearAlgebra" begin
     include_test("rulesets/LinearAlgebra/dense.jl")
     include_test("rulesets/LinearAlgebra/norm.jl")
+    include_test("rulesets/LinearAlgebra/matfun.jl")
     include_test("rulesets/LinearAlgebra/structured.jl")
     include_test("rulesets/LinearAlgebra/symmetric.jl")
     include_test("rulesets/LinearAlgebra/factorization.jl")