Update ChainRules definitions and add differential for PoissonBinomial pdf #162

Merged
merged 6 commits into from
Apr 25, 2021

devmotion
Member

This PR updates the ChainRules definitions and tests to match JuliaStats/StatsFuns.jl#106 (which ideally would allow us to remove the definitions in DistributionsAD).

Additionally, I derived and implemented the differential for the pdf of the PoissonBinomial distribution. It uses a dynamic programming approach similar to the implementation of the pdf in Distributions. Unfortunately, the recently added lazy initialization of PoissonBinomial (JuliaStats/Distributions.jl#1285) breaks Zygote support. I didn't manage to fix it with a custom differential for getproperty (or the Zygote analogue literal_getproperty). E.g., the following caused errors

    # this rrule is necessary since the primal mutates
    function ChainRulesCore.rrule(
        ::typeof(getproperty),
        d::T,
        x::Symbol,
    ) where {T<:PoissonBinomial}
        y, A = if getfield(d, :pmf) === nothing && x === :pmf
            getproperty(d, x), poissonbinomial_partialderivatives(d.p)
        else
            getproperty(d, x), nothing
        end
        
        function getproperty_PoissonBinomial_pullback(Δy)
            ∂d = if x === :pmf && A !== nothing
                Composite{T}(; pmf=A * Δy)
            else
                DoesNotExist()
            end
            return NO_FIELDS, ∂d, DoesNotExist()
        end

        return y, getproperty_PoissonBinomial_pullback
    end

similar to those reported in FluxML/Zygote.jl#566. I think this issue needs more time and therefore I marked the Zygote tests of PoissonBinomial as broken.
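For readers unfamiliar with the dynamic programming approach mentioned above, here is a minimal sketch in plain Julia. It is not the code from this PR or from Distributions: the names `pb_pmf` and `pb_pmf_jacobian` are hypothetical, the Jacobian layout (one column per success probability) is an assumption of this sketch, and the Jacobian is computed naively by rebuilding the pmf with one trial left out. It relies on the identity ∂P(X = k)/∂pᵢ = P(X₋ᵢ = k − 1) − P(X₋ᵢ = k), where X₋ᵢ counts the successes of all trials except the i-th.

```julia
# Hypothetical helpers for illustration only; not the functions used in
# Distributions or DistributionsAD.

# pmf of the number of successes in independent Bernoulli trials with
# success probabilities p, built up one trial at a time.
function pb_pmf(p::AbstractVector{T}) where {T<:Real}
    n = length(p)
    S = zeros(T, n + 1)          # S[k + 1] = P(X = k)
    S[1] = one(T)
    for (i, pᵢ) in enumerate(p)
        # Go from high k to low k so that S[k - 1] still holds the old value.
        for k in (i + 1):-1:2
            S[k] = (1 - pᵢ) * S[k] + pᵢ * S[k - 1]
        end
        S[1] *= 1 - pᵢ
    end
    return S
end

# Jacobian A with A[k + 1, i] = ∂P(X = k) / ∂p[i].
# Naive version: recompute the pmf without trial i for every i.
function pb_pmf_jacobian(p::AbstractVector{T}) where {T<:Real}
    n = length(p)
    A = zeros(T, n + 1, n)
    for i in 1:n
        S = pb_pmf(p[[1:i-1; i+1:n]])                 # pmf of X₋ᵢ, length n
        for k in 0:n
            lower = k ≥ 1 ? S[k] : zero(T)            # P(X₋ᵢ = k - 1)
            upper = k ≤ n - 1 ? S[k + 1] : zero(T)    # P(X₋ᵢ = k)
            A[k + 1, i] = lower - upper
        end
    end
    return A
end

p = [0.2, 0.5, 0.7]
S = pb_pmf(p)
A = pb_pmf_jacobian(p)
```

Because the pmf is linear in each individual pᵢ, a central finite difference of `pb_pmf` is a convenient check for each column of `A`.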

@devmotion devmotion requested a review from torfjelde April 22, 2021 16:47
Member

@torfjelde torfjelde left a comment

LGTM!
I haven't checked all the derivatives in detail, but that's what the tests are for 🙃

I have a couple of questions, but they're mostly just questions; don't think they'll require any change, so I just approve now 👍

src/chainrules.jl
    (@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)),
    @setup(hk = k / 2),
    (
        (log(x) - logtwo - digamma(hk)) / 2,

Member

Where is logtwo from? I see it's from StatsFuns, but I can't find the definition in StatsFuns 😕

I ask because I worry it might lead to undesirable type promotion, e.g. Float32 to Float64.

Member Author

It's defined in LogExpFunctions and just reexported from StatsFuns (the PR was finally merged some days ago 🎉). It is defined as an Irrational and hence should avoid type promotions if possible, e.g.:

    julia> logtwo - 3.4
    -2.7068528194400545

    julia> logtwo - 3.4f0
    -2.706853f0

Member

Good, good 👍 Yeah, in this case it should be fine. But there are several cases in DiffRules.jl (and I'm pretty certain the same ones exist in ChainRules.jl?) where the operations are written such that the Irrational is first converted to Float64, so the dual is promoted to Float64 even though the primal is Float32 (JuliaDiff/DiffRules.jl#55).

And the annoying bit is that some of these constants are defined in StatsFuns.jl, some in LogExpFunctions.jl, etc., so it's difficult to re-use them in DiffRules.jl 😕
Though maybe I should actually just add all those constants to that PR instead of doing oftype everywhere, hmm.

Member Author

It's only a problem if you perform operations with the constants (such as sqrt) before they are promoted to the correct type, which is not the case here.
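A small sketch of the distinction, using Base's `π` as a stand-in for `logtwo` (both are `Irrational` constants, so the same promotion rules should apply). The variable names are only for illustration.

```julia
x = 1.5f0                      # a Float32 primal

# The Irrational adapts to the type of the other operand.
a = π * x                      # Float32

# sqrt(π) is evaluated on its own first and yields a Float64,
# which then promotes the whole product.
b = sqrt(π) * x                # Float64

# Converting the constant to the primal's type before the operation
# keeps everything in Float32.
c = sqrt(oftype(x, π)) * x     # Float32

println(typeof(a), " ", typeof(b), " ", typeof(c))
```

The `oftype` variant is the pattern mentioned earlier in this thread; it avoids the promotion at the cost of writing the conversion at every use.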

@devmotion
Member Author

> I haven't checked all the derivatives in detail, but that's what the tests are for

On a side note, the testing utilities in ChainRulesTestUtils are really great (and they're documented now, too). They make it very easy to check the differentials and to ensure that they are correct. They even warn you if, e.g., you use a thunk but there is only one derivative.
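As a rough illustration of what such a check can look like, here is a minimal sketch. It assumes the ChainRulesCore and ChainRulesTestUtils 1.x API, which differs from the versions current at the time of this PR, and `mysquare` is a hypothetical function that is not part of DistributionsAD.

```julia
using ChainRulesCore
using ChainRulesTestUtils

# Hypothetical scalar function with a hand-written derivative.
mysquare(x) = x^2
@scalar_rule mysquare(x) 2x

# Compares the frule and rrule generated above against finite differences.
test_scalar(mysquare, 1.7)

# The reverse rule can also be checked on its own.
test_rrule(mysquare, 1.7)
```

Both calls run as test sets, so a wrong derivative shows up as a test failure rather than a silent mismatch.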

@devmotion devmotion requested a review from torfjelde April 25, 2021 15:17
@devmotion devmotion merged commit c463960 into master Apr 25, 2021
@devmotion devmotion deleted the dw/chainrules branch April 25, 2021 18:07
@devmotion devmotion mentioned this pull request Apr 26, 2021