diff --git a/Project.toml b/Project.toml
index 6e1e51e54..add472ad1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRules"
 uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "0.7.46"
+version = "0.7.47"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
-ChainRulesCore = "0.9.25"
+ChainRulesCore = "0.9.26"
 ChainRulesTestUtils = "0.5, 0.6.1"
 Compat = "3"
 FiniteDifferences = "0.11, 0.12"
diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl
index c8d65885a..a1985ec40 100644
--- a/src/rulesets/LinearAlgebra/symmetric.jl
+++ b/src/rulesets/LinearAlgebra/symmetric.jl
@@ -129,19 +129,19 @@ end
 # ∂U is overwritten if not an `AbstractZero`
 function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U)
     ∂λ isa AbstractZero && ∂U isa AbstractZero && return ∂λ + ∂U
-    ∂A = similar(A, eltype(U))
+    Ā = similar(parent(A), eltype(U))
     tmp = ∂U
     if ∂U isa AbstractZero
-        mul!(∂A.data, U, real.(∂λ) .* U')
+        mul!(Ā, U, real.(∂λ) .* U')
     else
         _eigen_norm_phase_rev!(∂U, A, U)
-        ∂K = mul!(∂A.data, U', ∂U)
+        ∂K = mul!(Ā, U', ∂U)
         ∂K ./= λ' .- λ
         ∂K[diagind(∂K)] .= real.(∂λ)
         mul!(tmp, ∂K, U')
-        mul!(∂A.data, U, tmp)
-        @inbounds _hermitrize!(∂A.data)
+        mul!(Ā, U, tmp)
     end
+    ∂A = _hermitrizelike!(Ā, A)
     return ∂A
 end
 
@@ -279,6 +279,172 @@ function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS
     return S, svdvals_pullback
 end
 
+#####
+##### matrix functions
+#####
+
+# Formula for frule (Fréchet derivative) from Daleckiĭ-Kreĭn theorem given in Theorem 3.11 of
+# Higham N.J. Functions of Matrices: Theory and Computation. 2008. ISBN: 978-0-898716-46-7.
+# rrule is derived from frule. These rules are more stable for degenerate matrices than
+# applying the chain rule to the rules for `eigen`.
+
+for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh)
+    @eval begin
+        function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
+            Y, intermediates = _matfun($func, A)
+            Ȳ = _matfun_frechet($func, A, Y, ΔA, intermediates)
+            # If ΔA was hermitian, then ∂Y has the same structure as Y
+            ∂Y = if ishermitian(ΔA) && (isa(Y, Symmetric) || isa(Y, Hermitian))
+                _symhermlike!(Ȳ, Y)
+            else
+                Ȳ
+            end
+            return Y, ∂Y
+        end
+
+        function rrule(::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
+            Y, intermediates = _matfun($func, A)
+            function $(Symbol(func, :_pullback))(ΔY)
+                # for Hermitian Y, we don't need to realify the diagonal of ΔY, since the
+                # effect is the same as applying _hermitrizelike! at the end
+                ∂Y = eltype(Y) <: Real ? real(ΔY) : ΔY
+                # for matrix functions, the pullback is related to the pushforward by an adjoint
+                Ā = _matfun_frechet($func, A, Y, ∂Y', intermediates)
+                # the cotangent of Hermitian A should be Hermitian
+                ∂A = _hermitrizelike!(Ā, A)
+                return NO_FIELDS, ∂A
+            end
+            return Y, $(Symbol(func, :_pullback))
+        end
+    end
+end
+
+function frule((_, ΔA), ::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
+    sinA, (λ, U, sinλ, cosλ) = _matfun(sin, A)
+    cosA = _symhermtype(sinA)((U * Diagonal(cosλ)) * U')
+    # We will overwrite tmp matrix several times to hold different values
+    tmp = mul!(similar(U, Base.promote_eltype(ΔA, U)), ΔA, U)
+    ∂Λ = mul!(similar(U), U', tmp)
+    ∂sinΛ = _muldiffquotmat!!(similar(∂Λ), sin, λ, sinλ, cosλ, ∂Λ)
+    ∂cosΛ = _muldiffquotmat!!(∂Λ, cos, λ, cosλ, -sinλ, ∂Λ)
+    ∂sinA = _symhermlike!(mul!(∂sinΛ, U, mul!(tmp, ∂sinΛ, U')), sinA)
+    ∂cosA = _symhermlike!(mul!(∂cosΛ, U, mul!(tmp, ∂cosΛ, U')), cosA)
+    Y = (sinA, cosA)
+    ∂Y = Composite{typeof(Y)}(∂sinA, ∂cosA)
+    return Y, ∂Y
+end
+
+function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
+    sinA, (λ, U, sinλ, cosλ) = _matfun(sin, A)
+    cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo)
+    Y = (sinA, cosA)
+    function sincos_pullback((ΔsinA, ΔcosA)::Composite)
+        ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA
+        if eltype(A) <: Real
+            ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA)
+        end
+        if ΔcosA isa AbstractZero
+            Ā = _matfun_frechet(sin, A, sinA, ΔsinA, (λ, U, sinλ, cosλ))
+        elseif ΔsinA isa AbstractZero
+            Ā = _matfun_frechet(cos, A, cosA, ΔcosA, (λ, U, cosλ, -sinλ))
+        else
+            # we will overwrite tmp with various temporary values during this computation
+            tmp = mul!(similar(U, Base.promote_eltype(U, ΔsinA, ΔcosA)), ΔsinA, U)
+            ∂sinΛ = mul!(similar(tmp), U', tmp)
+            ∂cosΛ = U' * mul!(tmp, ΔcosA, U)
+            ∂Λ = _muldiffquotmat!!(∂sinΛ, sin, λ, sinλ, cosλ, ∂sinΛ)
+            ∂Λ = _muldiffquotmat!!(∂Λ, cos, λ, cosλ, -sinλ, ∂cosΛ, true)
+            Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U'))
+        end
+        ∂A = _hermitrizelike!(Ā, A)
+        return NO_FIELDS, ∂A
+    end
+    return Y, sincos_pullback
+end
+
+"""
+    _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
+
+Compute the matrix function `f(A)` for real or complex hermitian `A`.
+The function returns a tuple containing the result and a tuple of intermediates to be
+reused by `_matfun_frechet` to compute the Fréchet derivative.
+Note any function `f` used with this **must** have a `frule` defined on it.
+"""
+function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
+    λ, U = eigen(A)
+    if all(λi -> _isindomain(f, λi), λ)
+        fλ_df_dλ = map(λi -> frule((Zero(), One()), f, λi), λ)
+    else # promote to complex if necessary
+        fλ_df_dλ = map(λi -> frule((Zero(), One()), f, complex(λi)), λ)
+    end
+    fλ = first.(fλ_df_dλ)
+    df_dλ = unthunk.(last.(fλ_df_dλ)) # unthunk the derivative, not the (value, derivative) tuple
+    fA = (U * Diagonal(fλ)) * U'
+    Y = if eltype(A) <: Real
+        Symmetric(fA)
+    elseif eltype(fλ) <: Complex
+        fA
+    else
+        Hermitian(fA)
+    end
+    intermediates = (λ, U, fλ, df_dλ)
+    return Y, intermediates
+end
+
+# Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations
+"""
+    _matfun_frechet(f, A::RealHermSymComplexHerm, Y, ΔA, intermediates)
+
+Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative
+of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`.
+"""
+function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ))
+    # We will overwrite tmp matrix several times to hold different values
+    tmp = mul!(similar(U, Base.promote_eltype(U, ΔA)), ΔA, U)
+    ∂Λ = mul!(similar(tmp), U', tmp)
+    ∂fΛ = _muldiffquotmat!!(∂Λ, f, λ, fλ, df_dλ, ∂Λ)
+    # reuse intermediate if possible
+    if eltype(tmp) <: Real && eltype(∂fΛ) <: Complex
+        tmp2 = ∂fΛ * U'
+    else
+        tmp2 = mul!(tmp, ∂fΛ, U')
+    end
+    ∂Y = mul!(∂fΛ, U, tmp2)
+    return ∂Y
+end
+
+# difference quotient, i.e. Pᵢⱼ = (f(λⱼ) - f(λᵢ)) / (λⱼ - λᵢ), with f'(λᵢ) when λᵢ=λⱼ
+function _diffquot(f, λi, λj, fλi, fλj, ∂fλi, ∂fλj)
+    T = Base.promote_typeof(λi, λj, fλi, fλj, ∂fλi, ∂fλj)
+    Δλ = λj - λi
+    iszero(Δλ) && return T(∂fλi)
+    # handle round-off error using Maclaurin series of (f(λᵢ + Δλ) - f(λᵢ)) / Δλ wrt Δλ
+    # and approximating f''(λᵢ) with forward difference (f'(λᵢ + Δλ) - f'(λᵢ)) / Δλ
+    # so (f(λᵢ + Δλ) - f(λᵢ)) / Δλ = (f'(λᵢ + Δλ) + f'(λᵢ)) / 2 + O(Δλ^2)
+    # total error on the order of f(λᵢ) * eps()^(2/3)
+    abs(Δλ) < cbrt(eps(real(T))) && return T((∂fλj + ∂fλi) / 2)
+    Δfλ = fλj - fλi
+    return T(Δfλ / Δλ)
+end
+
+# broadcast multiply Δ by the matrix of difference quotients P, storing the result in PΔ.
+# If β is nonzero, then @. PΔ = β*PΔ + P*Δ
+# if PΔ cannot hold the result, a complex matrix is allocated (a copy of PΔ when β ≠ 0)
+function _muldiffquotmat!!(PΔ, f, λ, fλ, ∂fλ, Δ, β = false)
+    if eltype(PΔ) <: Real && eltype(fλ) <: Complex
+        PΔ2 = iszero(β) ? similar(PΔ, complex(eltype(PΔ))) : complex(PΔ)
+        return _muldiffquotmat!!(PΔ2, f, λ, fλ, ∂fλ, Δ, β)
+    else
+        PΔ .= β .* PΔ .+ _diffquot.(f, λ, λ', fλ, transpose(fλ), ∂fλ, transpose(∂fλ)) .* Δ
+        return PΔ
+    end
+end
+
+_isindomain(f, x) = true
+_isindomain(::Union{typeof(acos),typeof(asin)}, x::Real) = -1 ≤ x ≤ 1
+_isindomain(::typeof(acosh), x::Real) = x ≥ 1
+_isindomain(::Union{typeof(log),typeof(sqrt)}, x::Real) = x ≥ 0
+
 #####
 ##### utilities
 #####
@@ -288,8 +454,24 @@ _symhermtype(::Type{<:Symmetric}) = Symmetric
 _symhermtype(::Type{<:Hermitian}) = Hermitian
 _symhermtype(A) = _symhermtype(typeof(A))
 
+# overwrite each diagonal entry of `A` with its real part
+function _realifydiag!(A)
+    for i in diagind(A)
+        @inbounds A[i] = real(A[i])
+    end
+    return A
+end
+
+# wrap `A` in the same type and `uplo` as `S`; for complex Hermitian `S` the diagonal of `A`
+# is made real first
+function _symhermlike!(A, S::Union{Symmetric,Hermitian})
+    S isa Hermitian{<:Complex} && _realifydiag!(A)
+    return typeof(S)(A, S.uplo)
+end
+
 # in-place hermitrize matrix
-function _hermitrize!(A)
+function _hermitrizelike!(A_, S::LinearAlgebra.RealHermSymComplexHerm)
+    A = eltype(S) <: Real ? real(A_) : A_
     n = size(A, 1)
     for i in 1:n
         for j in (i + 1):n
@@ -298,5 +480,5 @@ function _hermitrize!(A)
         end
         A[i, i] = real(A[i, i])
     end
-    return A
+    return _symhermtype(S)(A, Symbol(S.uplo))
 end
diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl
index cb2976ad6..85d00e726 100644
--- a/test/rulesets/LinearAlgebra/symmetric.jl
+++ b/test/rulesets/LinearAlgebra/symmetric.jl
@@ -293,4 +293,217 @@
             @test @inferred(back(Zero())) == (NO_FIELDS, Zero())
         end
     end
+
+    @testset "Symmetric/Hermitian matrix functions" begin
+        # generate random matrices of type TA in the domain of f
+        function rand_matfun_input(f, TA, T, uplo, n, hermout)
+            U = Matrix(qr(randn(T, n, n)).Q)
+            if hermout # f(A) will also be a TA
+                λ = if f in (acos, asin, atanh)
+                    2 .* rand(real(T), n) .- 1
+                elseif f in (log, sqrt)
+                    abs.(randn(real(T), n))
+                elseif f === acosh
+                    1 .+ abs.(randn(real(T), n))
+                else
+                    randn(real(T), n)
+                end
+            else
+                λ = if f === atanh
+                    2 .* rand(real(T), n) .- 1
+                else
+                    randn(real(T), n)
+                end
+            end
+            return TA(U * Diagonal(λ) * U', uplo)
+        end
+
+        # Adapted from ChainRulesTestUtils._is_inferrable
+        function is_inferrable(f, A)
+            try
+                @inferred f(A)
+                return true
+            catch err
+                err isa ErrorException || rethrow()
+                return false
+            end
+        end
+
+        @testset "$(f)(::$TA{<:$T})" for f in
+            (exp, log, sqrt, cos, sin, tan, cosh, sinh, tanh, acos, asin, atan, acosh, asinh, atanh),
+            TA in (Symmetric, Hermitian),
+            T in (TA <: Symmetric ? (Float64,) : (Float64, ComplexF64))
+            TC = Complex{real(T)}
+
+            n = 10
+            @testset "frule" begin
+                @testset for uplo in (:L, :U), hermout in (true, false)
+                    A, ΔA = rand_matfun_input(f, TA, T, uplo, n, hermout), TA(randn(T, n, n), uplo)
+                    Y = f(A)
+                    if is_inferrable(f, A)
+                        Y_ad, ∂Y_ad = @inferred frule((Zero(), ΔA), f, A)
+                    else
+                        TY = T∂Y = if T <: Real
+                            Union{Symmetric{Complex{T}},Symmetric{T}}
+                        else
+                            Union{Matrix{T},Hermitian{T}}
+                        end
+                        Y_ad, ∂Y_ad = @inferred Tuple{TY,T∂Y} frule((Zero(), ΔA), f, A)
+                    end
+                    @test Y_ad == Y
+                    @test typeof(Y_ad) === typeof(Y)
+                    hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo
+                    @test ∂Y_ad isa typeof(Y)
+                    hasproperty(∂Y_ad, :uplo) && @test ∂Y_ad.uplo == Y.uplo
+                    @test parent(∂Y_ad) ≈ jvp(_fdm, x -> Matrix{TC}(parent(f(TA(x, uplo)))), (A.data, ΔA.data))
+                end
+
+                @testset "stable for (almost-)singular input" begin
+                    λ, U = eigen(rand_matfun_input(f, TA, T, :U, n, true))
+                    m = div(n, 2)
+                    λ[1:m] .= λ[m+1:2m] .+ cbrt(eps(eltype(λ))) / 100
+                    A = TA(U * Diagonal(λ) * U')
+                    ΔA = TA(randn(T, n, n))
+                    _, ∂Y = frule((Zero(), ΔA), f, A)
+                    @test parent(∂Y) ≈ jvp(_fdm, x -> Matrix{TC}(parent(f(TA(x)))), (A.data, ΔA.data))
+
+                    λ[1:m] .= λ[m+1:2m]
+                    A2 = TA(U * Diagonal(λ) * U')
+                    ΔA2 = TA(randn(T, n, n))
+                    _, ∂Y2 = frule((Zero(), ΔA2), f, A2)
+                    @test parent(∂Y2) ≈ jvp(_fdm, x -> Matrix{TC}(parent(f(TA(x)))), (A2.data, ΔA2.data))
+                end
+
+                f ∉ (log,sqrt,acosh) && @testset "low-rank matrix" begin
+                    λ, U = eigen(rand_matfun_input(f, TA, T, :U, n, true))
+                    λ[2:n] .= 0
+                    A = TA(U * Diagonal(λ) * U')
+                    ΔA = TA(randn(T, n, n))
+                    _, ∂Y = frule((Zero(), ΔA), f, A)
+                    @test parent(∂Y) ≈ jvp(_fdm, x -> Matrix{TC}(parent(f(TA(x)))), (A.data, ΔA.data))
+                end
+            end
+
+            @testset "rrule" begin
+                @testset for uplo in (:L, :U), hermout in (true, false)
+                    A = rand_matfun_input(f, TA, T, uplo, n, hermout)
+                    Y = f(A)
+                    ΔY = if Y isa Matrix
+                        randn(eltype(Y), n, n)
+                    else
+                        typeof(Y)(randn(eltype(Y), n, n), Y.uplo)
+                    end
+                    if is_inferrable(f, A)
+                        Y_ad, back = @inferred rrule(f, A)
+                    else
+                        TY = if T <: Real
+                            Union{Symmetric{Complex{T}},Symmetric{T}}
+                        else
+                            Union{Matrix{T},Hermitian{T}}
+                        end
+                        Y_ad, back = @inferred Tuple{TY,Any} rrule(f, A)
+                    end
+                    @test Y_ad == Y
+                    @test typeof(Y_ad) === typeof(Y)
+                    hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo
+                    ∂self, ∂A = @inferred back(ΔY)
+                    @test ∂self === NO_FIELDS
+                    @test ∂A isa typeof(A)
+                    @test ∂A.uplo == A.uplo
+                    # check pullback composes correctly
+                    ∂data = rrule(Hermitian, A.data, uplo)[2](∂A)[2]
+                    @test ∂data ≈ j′vp(_fdm, x -> parent(f(TA(x, uplo))), ΔY, A.data)[1]
+
+                    # check works correctly even when cotangent is different type than Y
+                    ΔY2 = randn(Complex{real(T)}, n, n)
+                    _, ∂A2 = back(ΔY2)
+                    ∂data2 = rrule(Hermitian, A.data, uplo)[2](∂A2)[2]
+                    @test ∂data2 ≈ j′vp(_fdm, x -> Matrix{Complex{real(T)}}(f(TA(x, uplo))), ΔY2, A.data)[1]
+                end
+
+                @testset "stable for (almost-)singular input" begin
+                    λ, U = eigen(rand_matfun_input(f, TA, T, :U, n, true))
+                    m = div(n, 2)
+                    λ[1:m] .= λ[m+1:2m] .+ cbrt(eps(eltype(λ))) / 100
+                    A = TA(U * Diagonal(λ) * U')
+                    ΔY = TA(randn(T, n, n))
+                    ∂A = rrule(f, A)[2](ΔY)[2]
+                    ∂data = rrule(Hermitian, A.data, :U)[2](∂A)[2]
+                    @test ∂data ≈ j′vp(_fdm, x -> parent(f(TA(x))), ΔY, A.data)[1]
+
+                    λ[1:m] .= λ[m+1:2m]
+                    A2 = TA(U * Diagonal(λ) * U')
+                    ΔY2 = TA(randn(T, n, n))
+                    ∂A2 = rrule(f, A2)[2](ΔY2)[2]
+                    ∂data2 = rrule(Hermitian, A2.data, :U)[2](∂A2)[2]
+                    @test ∂data2 ≈ j′vp(_fdm, x -> parent(f(TA(x))), ΔY2, A2.data)[1]
+                end
+
+                f ∉ (log,sqrt,acosh) && @testset "low-rank matrix" begin
+                    λ, U = eigen(rand_matfun_input(f, TA, T, :U, n, true))
+                    λ[2:n] .= 0
+                    A = TA(U * Diagonal(λ) * U')
+                    ΔY = TA(randn(T, n, n))
+                    ∂A = rrule(f, A)[2](ΔY)[2]
+                    ∂data = rrule(Hermitian, A.data, :U)[2](∂A)[2]
+                    @test ∂data ≈ j′vp(_fdm, x -> parent(f(TA(x))), ΔY, A.data)[1]
+                end
+            end
+        end
+
+        @testset "sincos(::$TA{<:$T})" for TA in (Symmetric, Hermitian),
+            T in (TA <: Symmetric ? (Float64,) : (Float64, ComplexF64))
+
+            n = 10
+            @testset "frule" begin
+                @testset for uplo in (:L, :U)
+                    A, ΔA = TA(randn(T, n, n), uplo), TA(randn(T, n, n), uplo)
+                    Y = sincos(A)
+                    sinA, cosA = Y
+                    Y_ad, ∂Y_ad = @inferred frule((Zero(), ΔA), sincos, A)
+                    @test Y_ad == Y
+                    @test typeof(Y_ad) === typeof(Y)
+                    @test Y_ad[1].uplo === Y[1].uplo
+                    @test Y_ad[2].uplo === Y[2].uplo
+
+                    @test ∂Y_ad isa Composite{typeof(Y)}
+                    ∂Y_ad2 = Composite{typeof(Y)}(
+                        frule((Zero(), ΔA), sin, A)[2],
+                        frule((Zero(), ΔA), cos, A)[2],
+                    )
+                    # not exact because evaluated in a more efficient way
+                    @test ∂Y_ad ≈ ∂Y_ad2
+                end
+            end
+
+            @testset "rrule" begin
+                @testset for uplo in (:L, :U)
+                    A = TA(randn(T, n, n), uplo)
+                    Y = sincos(A)
+                    sinA, cosA = Y
+                    ΔsinA = typeof(sinA)(randn(T, n, n), sinA.uplo)
+                    ΔcosA = typeof(cosA)(randn(T, n, n), cosA.uplo)
+                    Y_ad, back = @inferred rrule(sincos, A)
+                    @test Y_ad == Y
+                    @test typeof(Y_ad) === typeof(Y)
+                    @test Y_ad[1].uplo === Y[1].uplo
+                    @test Y_ad[2].uplo === Y[2].uplo
+
+                    ΔY = Composite{typeof(Y)}(ΔsinA, ΔcosA)
+                    ∂self, ∂A = @inferred back(ΔY)
+                    @test ∂self === NO_FIELDS
+                    @test ∂A ≈ rrule(sin, A)[2](ΔsinA)[2] + rrule(cos, A)[2](ΔcosA)[2]
+
+                    ΔY2 = Composite{typeof(Y)}(Zero(), Zero())
+                    @test @inferred(back(ΔY2)) === (NO_FIELDS, Zero())
+
+                    ΔY3 = Composite{typeof(Y)}(ΔsinA, Zero())
+                    @test @inferred(back(ΔY3)) == rrule(sin, A)[2](ΔsinA)
+
+                    ΔY4 = Composite{typeof(Y)}(Zero(), ΔcosA)
+                    @test @inferred(back(ΔY4)) == rrule(cos, A)[2](ΔcosA)
+                end
+            end
+        end
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 981d4ee79..7362d68a1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,7 +3,7 @@ using ChainRules
 using ChainRulesCore
 using ChainRulesTestUtils
 using ChainRulesTestUtils: rand_tangent, _fdm
-using Compat: only
+using Compat: hasproperty, only
 using FiniteDifferences
 using FiniteDifferences: rand_tangent
 using SpecialFunctions