`_prior_kl` not differentiable for `Centered` with Zygote #129
Comments
The solution from https://github.com/theogf/KLDivergences.jl/blob/main/src/horrible_ad_workaround.jl does not work.
I was thinking that maybe the easiest (and maybe cheapest!) option would be to directly write the …
This seems to be yet another AD issue with PDMats. Maybe it's about time to add ChainRules (CR) rules to that repo.
Regardless, it might still be useful, and possibly more efficient, to add a CR definition for …
As usual, I have the forward rule, but I don't have the brain capacity to derive the reverse one.
I thought there was some friction about this, but I cannot find the discussion. EDIT: Ah no! That was in Distances.jl.
If you share the forward rule, I can see whether I manage to sort out the `rrule` :)
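As a sketch of what a hand-written reverse rule could look like, assuming plain vectors and dense matrices rather than PDMats types, here is the closed-form Gaussian KL divergence together with a pullback. The names `gauss_kl` and `gauss_kl_with_pullback` are hypothetical (not part of Distributions.jl or KLDivergences.jl); the pullback only mimics the `(value, pullback)` shape of a ChainRulesCore `rrule` and uses nothing beyond the LinearAlgebra standard library.

```julia
using LinearAlgebra

# Closed-form KL(N(μ1, Σ1) || N(μ0, Σ0)) for plain vectors and dense matrices.
# Hypothetical helper, not part of Distributions.jl.
function gauss_kl(μ1, Σ1, μ0, Σ0)
    k = length(μ1)
    δ = μ0 - μ1
    return (tr(Σ0 \ Σ1) + dot(δ, Σ0 \ δ) - k + logdet(Σ0) - logdet(Σ1)) / 2
end

# Forward pass plus a pullback, mirroring the (value, pullback) shape of an rrule.
# Gradients are w.r.t. the unconstrained matrix entries; for symmetric Σ0 and Σ1
# they come out symmetric.
function gauss_kl_with_pullback(μ1, Σ1, μ0, Σ0)
    y = gauss_kl(μ1, Σ1, μ0, Σ0)
    iΣ0 = inv(Σ0)
    iΣ1 = inv(Σ1)
    δ = μ0 - μ1
    function pullback(dy)
        dμ1 = dy * (iΣ0 * (μ1 - μ0))                      # ∂KL/∂μ1 = Σ0⁻¹(μ1 - μ0)
        dΣ1 = dy * (iΣ0 - iΣ1) / 2                        # ∂KL/∂Σ1 = (Σ0⁻¹ - Σ1⁻¹)/2
        dμ0 = dy * (iΣ0 * δ)                              # ∂KL/∂μ0 = Σ0⁻¹(μ0 - μ1)
        dΣ0 = dy * (iΣ0 - iΣ0 * (Σ1 + δ * δ') * iΣ0) / 2  # ∂KL/∂Σ0
        return dμ1, dΣ1, dμ0, dΣ0
    end
    return y, pullback
end

# Usage on the example from this thread: Σ(x) = [1 x; x 1] against Σ(0) = I, zero means.
kernel(x) = [1.0 x; x 1.0]
z = zeros(2)
y, back = gauss_kl_with_pullback(z, kernel(0.1), z, kernel(0.0))
_, dΣ1, _, _ = back(1.0)
# Chain rule through Σ(x), using dΣ/dx = [0 1; 1 0]
dkdx = sum(dΣ1 .* [0.0 1.0; 1.0 0.0])  # ≈ 0.1 / 0.99 ≈ 0.10101
```

Using `inv` keeps the sketch short; a real rule would more likely reuse a Cholesky factorization of each covariance.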
This is an example of one … There are two flavours of fix for this kind of problem:
In my view, the former is the way to go, and it is what the projection mechanism in CR lets you do. If I had to guess, I would say that … The fix is probably to get …
I tried the terms in the KL divergence one at a time:

```julia
using Distributions
using ForwardDiff
using LinearAlgebra  # for `tr`
using Zygote

# 2×2 covariance matrix with off-diagonal entry x
function kernel(x)
    return [1. x; x 1.]
end
d(x) = MvNormal(kernel(x))  # zero-mean MvNormal with covariance kernel(x)
kldivergence(d(0.1), d(0.0))

# Full KL divergence
f(x) = kldivergence(d(x), d(0.0))
ForwardDiff.gradient(x->f(only(x)), [.1]) # works
Zygote.gradient(x->f(only(x)), [.1]) # ERROR

# Individual terms
g(x) = logdetcov(d(x))
Zygote.gradient(x->g(only(x)), [.1]) # works
g(x) = sqmahal(d(x), zeros(2))
Zygote.gradient(x->g(only(x)), [.1]) # works
g(x) = length(d(x))
Zygote.gradient(x->g(only(x)), [.1]) # works
g(x) = tr(cov(d(0.0)) \ cov(d(x)))
Zygote.gradient(x->g(only(x)), [.1]) # ERROR

# All terms assembled manually
g(x) = (tr(cov(d(0.0)) \ cov(d(x))) + sqmahal(d(0.0), mean(d(x))) - length(d(x)) + logdetcov(d(0.0)) - logdetcov(d(x))) / 2
Zygote.gradient(x->g(only(x)), [.1]) # ERROR
```

Interestingly, the last two errors are different from the first:
Full test output
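For reference, the manual assembly in the last `g(x)` mirrors the closed-form KL divergence between two Gaussians, with $\mu_1, \Sigma_1$ taken from `d(x)` and $\mu_0, \Sigma_0$ from `d(0.0)`. For the example kernel ($\Sigma_0 = I$, zero means) it reduces to a scalar function of $x$, which gives an independent check on the value the ForwardDiff call is expected to return:

```math
\begin{aligned}
\mathrm{KL}\big(\mathcal{N}(\mu_1,\Sigma_1)\,\|\,\mathcal{N}(\mu_0,\Sigma_0)\big)
&= \tfrac{1}{2}\Big[\operatorname{tr}(\Sigma_0^{-1}\Sigma_1) + (\mu_1-\mu_0)^\top \Sigma_0^{-1}(\mu_1-\mu_0) - k + \ln\det\Sigma_0 - \ln\det\Sigma_1\Big] \\
\Sigma_1 = \begin{pmatrix} 1 & x \\ x & 1 \end{pmatrix},\ \Sigma_0 = I,\ \mu_0 = \mu_1 = 0
&\;\Rightarrow\; \mathrm{KL} = \tfrac{1}{2}\big[2 - 2 - \ln(1 - x^2)\big] = -\tfrac{1}{2}\ln(1 - x^2) \\
\frac{d\,\mathrm{KL}}{dx} &= \frac{x}{1 - x^2} \approx 0.10101 \quad \text{at } x = 0.1
\end{aligned}
```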
What's going on here?
All of them are PDMats issues, it seems.
I was expecting the `+` issue to show up when I assembled the terms manually. Instead, the last two errors seem to fail at an earlier stage.
I opened an issue: JuliaStats/PDMats.jl#159.
Despite JuliaDiff/ChainRules.jl#613, this still seems to be broken :(
The PDMats issues (JuliaStats/PDMats.jl#159) are not fixed yet.
Here is a simplified view of the problem from @simsurace:
With the following error: