Skip to content

Commit

Permalink
Symmetric/Hermitian matrix function rules (#193)
Browse files Browse the repository at this point in the history
* Add symmetric/hermitian eigendecomposition rules

* Add utility functions

* Add frules and rrules for sym/herm power series

* Add int pow rules

* Add sincos rules

* Remove unused function argument

* Fix and comment _nonzero

* Make methods and signatures less ambiguous

* Handle Zero() better

* Standardize notation

* Remove parens

* Update src/rulesets/LinearAlgebra/structured.jl

Co-authored-by: Nick Robinson <npr251@gmail.com>

* Fix for Julia 1.0

* Use correct variable and method name

* Accumulate in the triangle in the pullback

* Remove comment

* Add eigen and eigvals tests

* Remove outdated comment

* Clean up and make constraint functions faster

* Make outputs of int pow of Hermitian are Hermitian

* Fix typo in comment

* Test most power series functions

* Don't thunk tangents

* Make type-stable and use optimal threshold

* Split out symmetric/hermitian methods/tests

* Use correct pullback of hermitrization

* Stabilize eigenvector computation

* Test composed pullback

* Remove all eigendecomposition rules

Moved to #323

* Move to utilities section

* Move to utilities section

* Separate shared code into its own function

* Don't thunk

* Use correct function name

* Correctly broadcast

* Remove power rules

* Rename to matrix functions

* Remove pow tests

* Expand test suite

* Remove sincos rules for now

* Add references and comments

* Add _isindomain

* Refactor _matfun

Return cache and handle type-unstable case

* Add _matfun_frechet

* Broadcast instead of indexing

* Add comments and use indexing from paper

* Handle Zeros

* Contrain differentials according to primals

* Support all matrix functions

* Remove unused methods

* Support Symmetric{Complex}

* Add rules for sincos

* Make atanh rule type-stable

* Correctly test type-unstable functions

* Use correct denominator

* Add tests for almost-singular and low-rank matrices

* Remove out-dated comments

* Test alternate differentials

* Don't use only

* Remove _hermitrizeback!

* Don't use hasproperty, not in old Julia versions

* Reduce allocations

* simplify section name

* Simplify line

* Handle mixture of non-Zero and Zero

* Don't loop over unused functions

* Test against component frules instead of fd

* Test that rules produce same uplo as primal

* Apply suggestions from code review

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>

* Reuse variable name

In this case it is type-stable

* Use bang bang convention for maybe-in-place

* Don't assume the wrapped matrix is mutable

* Replace hermitrize!

* Use diagind

* Remove handling of Zero differential

* Unify symbols

* Use hasproperty

* Load hasproperty from Compat

* Replace refs with one to Higham

* Add docstrings

* Update src/rulesets/LinearAlgebra/symmetric.jl

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>

* Increment version number

* Use utility function

* Stabilize jvp Jacobian dimensions

* Don't use non-exported function

* Bump required ChainRulesCore

Co-authored-by: Nick Robinson <npr251@gmail.com>
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
  • Loading branch information
3 people authored Jan 14, 2021
1 parent f540c44 commit c59e6f8
Show file tree
Hide file tree
Showing 4 changed files with 402 additions and 10 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.46"
version = "0.7.47"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -12,7 +12,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.25"
ChainRulesCore = "0.9.26"
ChainRulesTestUtils = "0.5, 0.6.1"
Compat = "3"
FiniteDifferences = "0.11, 0.12"
Expand Down
193 changes: 186 additions & 7 deletions src/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,19 @@ end
# ∂U is overwritten if not an `AbstractZero`
function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U)
∂λ isa AbstractZero && ∂U isa AbstractZero && return ∂λ + ∂U
∂A = similar(A, eltype(U))
= similar(parent(A), eltype(U))
tmp = ∂U
if ∂U isa AbstractZero
mul!(∂A.data, U, real.(∂λ) .* U')
mul!(, U, real.(∂λ) .* U')
else
_eigen_norm_phase_rev!(∂U, A, U)
∂K = mul!(∂A.data, U', ∂U)
∂K = mul!(, U', ∂U)
∂K ./= λ' .- λ
∂K[diagind(∂K)] .= real.(∂λ)
mul!(tmp, ∂K, U')
mul!(∂A.data, U, tmp)
@inbounds _hermitrize!(∂A.data)
mul!(Ā, U, tmp)
end
∂A = _hermitrizelike!(Ā, A)
return ∂A
end

Expand Down Expand Up @@ -279,6 +279,172 @@ function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS
return S, svdvals_pullback
end

#####
##### matrix functions
#####

# Formula for frule (Fréchet derivative) from Daleckiĭ-Kreĭn theorem given in Theorem 3.11 of
# Higham N.J. Functions of Matrices: Theory and Computation. 2008. ISBN: 978-0-898716-46-7.
# rrule is derived from frule. These rules are more stable for degenerate matrices than
# applying the chain rule to the rules for `eigen`.

for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh)
@eval begin
function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
Y, intermediates = _matfun($func, A)
= _matfun_frechet($func, A, Y, ΔA, intermediates)
# If ΔA was hermitian, then ∂Y has the same structure as Y
∂Y = if ishermitian(ΔA) && (isa(Y, Symmetric) || isa(Y, Hermitian))
_symhermlike!(Ȳ, Y)
else
end
return Y, ∂Y
end

function rrule(::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
Y, intermediates = _matfun($func, A)
function $(Symbol(func, :_pullback))(ΔY)
# for Hermitian Y, we don't need to realify the diagonal of ΔY, since the
# effect is the same as applying _hermitrizelike! at the end
∂Y = eltype(Y) <: Real ? real(ΔY) : ΔY
# for matrix functions, the pullback is related to the pushforward by an adjoint
= _matfun_frechet($func, A, Y, ∂Y', intermediates)
# the cotangent of Hermitian A should be Hermitian
∂A = _hermitrizelike!(Ā, A)
return NO_FIELDS, ∂A
end
return Y, $(Symbol(func, :_pullback))
end
end
end

function frule((_, ΔA), ::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
sinA, (λ, U, sinλ, cosλ) = _matfun(sin, A)
cosA = _symhermtype(sinA)((U * Diagonal(cosλ)) * U')
# We will overwrite tmp matrix several times to hold different values
tmp = mul!(similar(U, Base.promote_eltype(ΔA, U)), ΔA, U)
∂Λ = mul!(similar(U), U', tmp)
∂sinΛ = _muldiffquotmat!!(similar(∂Λ), sin, λ, sinλ, cosλ, ∂Λ)
∂cosΛ = _muldiffquotmat!!(∂Λ, cos, λ, cosλ, -sinλ, ∂Λ)
∂sinA = _symhermlike!(mul!(∂sinΛ, U, mul!(tmp, ∂sinΛ, U')), sinA)
∂cosA = _symhermlike!(mul!(∂cosΛ, U, mul!(tmp, ∂cosΛ, U')), cosA)
Y = (sinA, cosA)
∂Y = Composite{typeof(Y)}(∂sinA, ∂cosA)
return Y, ∂Y
end

function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
sinA, (λ, U, sinλ, cosλ) = _matfun(sin, A)
cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo)
Y = (sinA, cosA)
function sincos_pullback((ΔsinA, ΔcosA)::Composite)
ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA
if eltype(A) <: Real
ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA)
end
if ΔcosA isa AbstractZero
= _matfun_frechet(sin, A, sinA, ΔsinA, (λ, U, sinλ, cosλ))
elseif ΔsinA isa AbstractZero
= _matfun_frechet(cos, A, cosA, ΔcosA, (λ, U, cosλ, -sinλ))
else
# we will overwrite tmp with various temporary values during this computation
tmp = mul!(similar(U, Base.promote_eltype(U, ΔsinA, ΔcosA)), ΔsinA, U)
∂sinΛ = mul!(similar(tmp), U', tmp)
∂cosΛ = U' * mul!(tmp, ΔcosA, U)
∂Λ = _muldiffquotmat!!(∂sinΛ, sin, λ, sinλ, cosλ, ∂sinΛ)
∂Λ = _muldiffquotmat!!(∂Λ, cos, λ, cosλ, -sinλ, ∂cosΛ, true)
= mul!(∂Λ, U, mul!(tmp, ∂Λ, U'))
end
∂A = _hermitrizelike!(Ā, A)
return NO_FIELDS, ∂A
end
return Y, sincos_pullback
end

"""
_matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
Compute the matrix function `f(A)` for real or complex hermitian `A`.
The function returns a tuple containing the result and a tuple of intermediates to be
reused by `_matfun_frechet` to compute the Fréchet derivative.
Note any function `f` used with this **must** have a `frule` defined on it.
"""
function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
λ, U = eigen(A)
if all(λi -> _isindomain(f, λi), λ)
fλ_df_dλ = map(λi -> frule((Zero(), One()), f, λi), λ)
else # promote to complex if necessary
fλ_df_dλ = map(λi -> frule((Zero(), One()), f, complex(λi)), λ)
end
= first.(fλ_df_dλ)
df_dλ = last.(unthunk.(fλ_df_dλ))
fA = (U * Diagonal(fλ)) * U'
Y = if eltype(A) <: Real
Symmetric(fA)
elseif eltype(fλ) <: Complex
fA
else
Hermitian(fA)
end
intermediates = (λ, U, fλ, df_dλ)
return Y, intermediates
end

# Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations
"""
_matfun_frechet(f, A::RealHermSymComplexHerm, Y, ΔA, intermediates)
Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative
of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`.
"""
function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ))
# We will overwrite tmp matrix several times to hold different values
tmp = mul!(similar(U, Base.promote_eltype(U, ΔA)), ΔA, U)
∂Λ = mul!(similar(tmp), U', tmp)
∂fΛ = _muldiffquotmat!!(∂Λ, f, λ, fλ, df_dλ, ∂Λ)
# reuse intermediate if possible
if eltype(tmp) <: Real && eltype(∂fΛ) <: Complex
tmp2 = ∂fΛ * U'
else
tmp2 = mul!(tmp, ∂fΛ, U')
end
∂Y = mul!(∂fΛ, U, tmp2)
return ∂Y
end

# difference quotient, i.e. Pᵢⱼ = (f(λⱼ) - f(λᵢ)) / (λⱼ - λᵢ), with f'(λᵢ) when λᵢ=λⱼ
function _diffquot(f, λi, λj, fλi, fλj, ∂fλi, ∂fλj)
T = Base.promote_typeof(λi, λj, fλi, fλj, ∂fλi, ∂fλj)
Δλ = λj - λi
iszero(Δλ) && return T(∂fλi)
# handle round-off error using Maclaurin series of (f(λᵢ + Δλ) - f(λᵢ)) / Δλ wrt Δλ
# and approximating f''(λᵢ) with forward difference (f'(λᵢ + Δλ) - f'(λᵢ)) / Δλ
# so (f(λᵢ + Δλ) - f(λᵢ)) / Δλ = (f'(λᵢ + Δλ) + f'(λᵢ)) / 2 + O(Δλ^2)
# total error on the order of f(λᵢ) * eps()^(2/3)
abs(Δλ) < cbrt(eps(real(T))) && return T((∂fλj + ∂fλi) / 2)
Δfλ = fλj - fλi
return T(Δfλ / Δλ)
end

# broadcast multiply Δ by the matrix of difference quotients P, storing the result in PΔ.
# If β is is nonzero, then @. PΔ = β*PΔ + P*Δ
# if type of PΔ is incompatible with result, new matrix is allocated
function _muldiffquotmat!!(PΔ, f, λ, fλ, ∂fλ, Δ, β = false)
if eltype(PΔ) <: Real && eltype(fλ) <: Complex
PΔ2 = similar(PΔ, complex(eltype(PΔ)))
return _muldiffquotmat!!(PΔ2, f, λ, fλ, ∂fλ, Δ, β)
else
PΔ .= β .*.+ _diffquot.(f, λ, λ', fλ, transpose(fλ), ∂fλ, transpose(∂fλ)) .* Δ
return
end
end

_isindomain(f, x) = true
_isindomain(::Union{typeof(acos),typeof(asin)}, x::Real) = -1 x 1
_isindomain(::typeof(acosh), x::Real) = x 1
_isindomain(::Union{typeof(log),typeof(sqrt)}, x::Real) = x 0

#####
##### utilities
#####
Expand All @@ -288,8 +454,21 @@ _symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

function _realifydiag!(A)
for i in diagind(A)
@inbounds A[i] = real(A[i])
end
return A
end

function _symhermlike!(A, S::Union{Symmetric,Hermitian})
A isa Hermitian{<:Complex} && _realifydiag!(A)
return typeof(S)(A, S.uplo)
end

# in-place hermitrize matrix
function _hermitrize!(A)
function _hermitrizelike!(A_, S::LinearAlgebra.RealHermSymComplexHerm)
A = eltype(S) <: Real ? real(A_) : A_
n = size(A, 1)
for i in 1:n
for j in (i + 1):n
Expand All @@ -298,5 +477,5 @@ function _hermitrize!(A)
end
A[i, i] = real(A[i, i])
end
return A
return _symhermtype(S)(A, Symbol(S.uplo))
end
Loading

2 comments on commit c59e6f8

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/27987

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.47 -m "<description of version>" c59e6f8a5c0551358e64c72817053a7075390443
git push origin v0.7.47

Please sign in to comment.