Update ChainRules definitions and add differential for PoissonBinomial pdf #162
LGTM!
I haven't checked all the derivatives in detail, but that's what the tests are for 🙃
I have a couple of questions, but they're mostly just questions; don't think they'll require any change, so I just approve now 👍
(@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)),
@setup(hk = k / 2),
(
(log(x) - logtwo - digamma(hk)) / 2,
Where is `logtwo` from? I see it's from StatsFuns, but I can't find the definition in StatsFuns 😕
I ask because I worry it might lead to undesirable type promotion, e.g. `Float32` to `Float64`.
It's defined in LogExpFunctions and just reexported from StatsFuns (the PR was finally merged some days ago 🎉). It is defined as an `Irrational` and hence should avoid type promotions if possible, e.g.:
julia> logtwo - 3.4
-2.7068528194400545
julia> logtwo - 3.4f0
-2.706853f0
Good, good 👍 Yeah, in this case it should be fine, but there are several cases in DiffRules.jl (and I'm pretty certain the same ones exist in ChainRules.jl?) where the operations are written such that the `Irrational` is first converted to `Float64`, and thus the dual is promoted to `Float64` despite the primal being `Float32` (JuliaDiff/DiffRules.jl#55).
And the annoying bit is that some of these constants are defined in StatsFuns.jl, some in LogExpFunctions.jl, etc., so it's difficult to re-use them in DiffRules.jl 😕
Though maybe I should actually just add all those constants to that PR instead of doing `oftype` everywhere, hmm.
It's only a problem if you perform operations with the constants (such as `sqrt`) before they are promoted to the correct type, which is not the case here.
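To make the distinction concrete, here is a minimal sketch using only Base Julia. The variable names are arbitrary, and `π` stands in for any `Irrational` constant such as `logtwo`:

```julia
x = 1.5f0  # Float32 primal

# The Irrational adapts to the other operand, so the product stays Float32.
a = π * x

# sqrt(π) is evaluated on its own and returns a Float64,
# which then promotes the whole product to Float64.
b = sqrt(π) * x

# Converting the constant to the type of x first keeps everything in Float32.
c = sqrt(oftype(x, π)) * x

@show typeof(a) typeof(b) typeof(c)
```

Only the second form widens the result; the rules in this PR use the constants in the first way.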
On a side note, the testing utilities in ChainRulesTestUtils are really great (and they're also documented now). They make it very easy to check the differentials and to ensure that they are correct. And it even warns you if e.g. you use a thunk but there is only one derivative.
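For anyone who hasn't used it yet, a check looks roughly like the sketch below. `halflogsq` is a hypothetical function made up for illustration and is not part of this PR; the snippet assumes ChainRulesCore and ChainRulesTestUtils are installed:

```julia
using ChainRulesCore, ChainRulesTestUtils

# Hypothetical scalar function, used only for illustration.
halflogsq(x) = log(x)^2 / 2

# Declare its derivative: d/dx (log(x)^2 / 2) = log(x) / x.
@scalar_rule(halflogsq(x), log(x) / x)

# Compares the generated frule and rrule against finite differences.
test_scalar(halflogsq, 1.7)
```

If the derivative expression were wrong, the test set run by `test_scalar` would report the mismatch.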
The PR updates the ChainRules definitions and tests, according to JuliaStats/StatsFuns.jl#106 (which ideally would allow us to remove the definitions in DistributionsAD).
Additionally, I derived and implemented the differential for the pdf of the PoissonBinomial distribution. It uses a dynamic programming style similar to the implementation of the pdf in Distributions. Unfortunately, the recently added lazy initialization of `PoissonBinomial` (JuliaStats/Distributions.jl#1285) breaks Zygote support. I didn't manage to fix it with a custom differential for `getproperty` (or the Zygote analogue `literal_getproperty`). E.g., the following caused errors similar to those reported in FluxML/Zygote.jl#566. I think this issue needs more time and therefore I marked the Zygote tests of PoissonBinomial as broken.
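For readers unfamiliar with the dynamic programming idea, the sketch below shows the textbook recursion for the probability mass function only. It is not the code in Distributions or in this PR, and `poisbin_pmf` is a hypothetical name chosen for illustration:

```julia
# Probabilities of k = 0, ..., n successes among independent
# Bernoulli trials with success probabilities p[1], ..., p[n].
function poisbin_pmf(p::AbstractVector{<:Real})
    T = float(eltype(p))
    n = length(p)
    S = zeros(T, n + 1)  # S[k + 1] holds P(k successes among the trials seen so far)
    S[1] = one(T)
    for (i, q) in enumerate(p)
        # Update from high k to low k so that S[k - 1] is still the old value.
        for k in (i + 1):-1:2
            S[k] = (1 - q) * S[k] + q * S[k - 1]
        end
        S[1] *= 1 - q
    end
    return S
end

poisbin_pmf([0.5, 0.5])  # returns [0.25, 0.5, 0.25]
```

Each trial updates the whole table in O(n) operations, so the full pmf costs O(n^2). A differential can be built with a recursion of the same shape, which is the style referred to above.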