_prior_kl not differentiable for Centered with Zygote #129

Open
theogf opened this issue Apr 27, 2022 · 14 comments
Labels: AD issue (problem with automatic differentiation), bug (something isn't working), help wanted (extra attention is needed)

Comments

@theogf
Member

theogf commented Apr 27, 2022

Here is a simplified view of the problem, from @simsurace:

using Distributions
using Zygote

function DKL(par1, par2)
    K1 = [par1[1] par1[2]; par1[2] par1[1]]
    K2 = [par2[1] par2[2]; par2[2] par2[1]]
    return kldivergence(
        MvNormal(K1),
        MvNormal(K2)
    )
end

Zygote.gradient(par2 -> DKL([1., 0.1], par2), [1., 0.1])

With the following error:

ERROR: MethodError: no method matching +(::NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}, ::Matrix{Float64})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/julia-1.7.1/share/julia/base/operators.jl:655
  +(::FillArrays.Zeros{T, N}, ::AbstractArray{V, N}) where {T, V, N} at ~/.julia/packages/FillArrays/5Arin/src/fillalgebra.jl:228
  +(::Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_arithmetic.jl:146
  ...
Stacktrace:
  [1] accum(x::NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}, y::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:17
  [2] macro expansion
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:27 [inlined]
  [3] accum(x::NamedTuple{(:μ, :Σ), Tuple{Vector{Float64}, NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}}}, y::NamedTuple{(:μ, :Σ), Tuple{Nothing, Matrix{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:27
  [4] accum(x::NamedTuple{(:μ, :Σ), Tuple{Nothing, NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}}}, y::NamedTuple{(:μ, :Σ), Tuple{Vector{Float64}, Nothing}}, zs::NamedTuple{(:μ, :Σ), Tuple{Nothing, Matrix{Float64}}}) (repeats 2 times)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:22
  [5] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:110 [inlined]
  [6] (::typeof(∂(kldivergence)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./REPL[16]:4 [inlined]
  [8] (::typeof(∂(DKL)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [9] Pullback
    @ ./REPL[18]:1 [inlined]
 [10] (::typeof(∂(#5)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#56#57"{typeof(∂(#5))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [13] top-level scope
    @ REPL[18]:1

@theogf
Member Author

theogf commented Apr 27, 2022

I was thinking maybe the easiest (and maybe cheapest!) would be to directly write the rrule for kldivergence

@devmotion
Member

This seems to be yet another AD issue with PDMats. Maybe about time to add CR to that repo.

@devmotion
Member

Regardless, it might still be useful, and possibly more efficient, to add a CR definition for kldivergence directly.

@theogf
Member Author

theogf commented Apr 27, 2022

As usual, I have the forward rule, but I don't have the brain capacity to derive the reverse one.

@theogf
Member Author

theogf commented Apr 27, 2022

> This seems to be yet another AD issue with PDMats. Maybe about time to add CR to that repo.

I thought there was some friction for this but I cannot find the discussion

EDIT: Ah no! That's in Distances.jl

@theogf theogf added bug Something isn't working help wanted Extra attention is needed AD issue Problem with automatic differentiation labels Apr 27, 2022
@st--
Member

st-- commented Apr 28, 2022

> As usual, I have the forward rule, but I don't have the brain capacity to derive the reverse one.

If you share the forward rule I can see if I manage to sort out the rrule:)

@willtebbutt
Member

This is an example of one rrule (for some function utilised in kldivergence) returning a Diagonal matrix (a "natural" tangent), and another digging into the PDMat and returning a NamedTuple (the "structural" tangent). You can deduce that this kind of thing is going on from the call to accum with a NamedTuple and a Diagonal -- generally speaking, accum is roughly equivalent to +, so if it's not obvious how to add two types, accum probably won't work for them without manual intervention.

There are two flavours of fix for this kind of problem:

  1. prevent the accum error by ensuring that both tangents are converted to a common type before hitting the call to accum, and
  2. implement accum for the two types in question.

In my view the former is the way to go, and it is what the projection mechanism in CR lets you do.

If I had to guess, I would say that _cov(q) \ _cov(p) is giving the natural, and logdetcov(q) / logdetcov(p) the structural, but you would have to check.

The fix is probably to get PDMats onto CR, as @devmotion suggests, and presumably to implement the projection mechanism for it.
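
For reference, the quantity being differentiated and the reverse-mode sensitivities discussed above can be written out as follows. This is a sketch, not taken from the thread: it assumes p = N(μ_p, Σ_p) and q = N(μ_q, Σ_q) in k dimensions, with kldivergence(p, q) meaning KL(p ‖ q), and it treats the entries of each (symmetric) covariance as unconstrained. The symbols δ and k are introduced here for illustration.

```latex
\begin{aligned}
\delta &= \mu_p - \mu_q, \\
\mathrm{KL}(p \,\|\, q) &= \tfrac{1}{2}\Big[
    \operatorname{tr}\!\big(\Sigma_q^{-1}\Sigma_p\big)
    + \delta^\top \Sigma_q^{-1} \delta
    - k
    + \log\det\Sigma_q - \log\det\Sigma_p \Big], \\[4pt]
\frac{\partial\,\mathrm{KL}}{\partial \mu_p} &= \Sigma_q^{-1}\delta,
\qquad
\frac{\partial\,\mathrm{KL}}{\partial \mu_q} = -\,\Sigma_q^{-1}\delta, \\[4pt]
\frac{\partial\,\mathrm{KL}}{\partial \Sigma_p} &= \tfrac{1}{2}\big(\Sigma_q^{-1} - \Sigma_p^{-1}\big), \\[4pt]
\frac{\partial\,\mathrm{KL}}{\partial \Sigma_q} &= \tfrac{1}{2}\Big(\Sigma_q^{-1}
    - \Sigma_q^{-1}\big(\Sigma_p + \delta\delta^\top\big)\Sigma_q^{-1}\Big).
\end{aligned}
```

A hand-written rrule would scale each of these by the incoming scalar cotangent and would presumably return them as plain matrices, so that every branch produces the same ("natural") tangent type.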

@simsurace
Member

I tried the terms in kldivergence one by one:

using Distributions
using ForwardDiff
using LinearAlgebra  # provides tr, used below
using Zygote

function kernel(x)
    return [1. x; x 1.]
end

d(x) = MvNormal(kernel(x))

kldivergence(d(0.1), d(0.0))

f(x) = kldivergence(d(x), d(0.0))
ForwardDiff.gradient(x->f(only(x)), [.1]) # works
Zygote.gradient(x->f(only(x)), [.1]) # ERROR

g(x) = logdetcov(d(x))
Zygote.gradient(x->g(only(x)), [.1]) # works

g(x) = sqmahal(d(x), zeros(2))
Zygote.gradient(x->g(only(x)), [.1]) # works

g(x) = length(d(x))
Zygote.gradient(x->g(only(x)), [.1]) # works

g(x) = tr(cov(d(0.0)) \ cov(d(x)))
Zygote.gradient(x->g(only(x)), [.1]) # ERROR

g(x) = (tr(cov(d(0.0)) \ cov(d(x))) + sqmahal(d(0.0), mean(d(x))) - length(d(x)) + logdetcov(d(0.0)) - logdetcov(d(x))) / 2
Zygote.gradient(x->g(only(x)), [.1]) # ERROR

Interestingly, the last two errors are different from the first:

Full test output

ERROR: Need an adjoint for constructor PDMats.PDMat{Float64, Matrix{Float64}}. Gradient is of type Diagonal{Float64, Vector{Float64}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{PDMats.PDMat{Float64, Matrix{Float64}}, Nothing, false})(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:324
  [3] (::Zygote.var"#1784#back#228"{Zygote.Jnew{PDMats.PDMat{Float64, Matrix{Float64}}, Nothing, false}})(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:9 [inlined]
  [5] (::typeof(∂(PDMats.PDMat{Float64, Matrix{Float64}})))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:16 [inlined]
  [7] (::typeof(∂(PDMats.PDMat)))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:19 [inlined]
  [9] (::typeof(∂(PDMats.PDMat)))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:201 [inlined]
 [11] (::typeof(∂(MvNormal)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:218 [inlined]
 [13] (::typeof(∂(MvNormal)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/Documents/projects/n/ai-timeseries-prototypes/autodiff/kldiff.jl:9 [inlined]
 [15] (::typeof(∂(d)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./REPL[40]:1 [inlined]
 [17] (::typeof(∂(g)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./REPL[41]:1 [inlined]
 [19] (::typeof(∂(#33)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#56#57"{typeof(∂(#33))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [21] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [22] top-level scope
    @ REPL[41]:1

What's going on here?

@devmotion
Member

All of them are PDMats issues, it seems.

@simsurace
Member

I was expecting the + issue to show up when I assemble the terms manually. Instead, the last two errors seem to fail at an earlier stage.

@simsurace
Member

I opened an issue: JuliaStats/PDMats.jl#159

@st--
Member

st-- commented Sep 12, 2022

Despite JuliaDiff/ChainRules.jl#613 this seems to still be broken :(

@devmotion
Member

The PDMats issues (JuliaStats/PDMats.jl#159) are not fixed yet.
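
Until that is resolved, one possible workaround is to bypass MvNormal and PDMat entirely and compute the divergence on plain matrices. The following is only a sketch under assumptions: kl_zero_mean, cov_from and dkl_dense are hypothetical helper names (not part of Distributions.jl), it covers only the zero-mean case from the original report, and it has not been checked against the package versions shown in the stack traces above.

```julia
using LinearAlgebra
using Zygote

# Hypothetical helper (not part of Distributions.jl): KL(p || q) for two
# zero-mean Gaussians, written on plain matrices so no PDMat is constructed.
function kl_zero_mean(Σp::AbstractMatrix, Σq::AbstractMatrix)
    k = size(Σp, 1)
    return (tr(Σq \ Σp) - k + logdet(Σq) - logdet(Σp)) / 2
end

# Same 2x2 parameterisation as in the original report.
cov_from(par) = [par[1] par[2]; par[2] par[1]]

# Counterpart of DKL from the report, without MvNormal.
dkl_dense(par1, par2) = kl_zero_mean(cov_from(par1), cov_from(par2))

# Gradient with respect to the parameters of the second distribution.
grad, = Zygote.gradient(par2 -> dkl_dense([1.0, 0.1], par2), [1.0, 0.1])
```

Since only dense-matrix operations are involved, every tangent should be an ordinary array, which sidesteps the mixed NamedTuple/matrix accumulation. The cost is losing the cached Cholesky factorisation that PDMat provides.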

5 participants