From 319860c03cbd7826095908b47397422adb8211d1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 2 Jan 2021 20:04:02 +0100 Subject: [PATCH 1/8] Add ChainRules adjoints --- Project.toml | 6 ++++- src/StatsFuns.jl | 1 + src/distrs/beta.jl | 11 +++++++++ src/distrs/binom.jl | 12 ++++++++++ src/distrs/chisq.jl | 12 +++++++++- src/distrs/fdist.jl | 17 ++++++++++++++ src/distrs/gamma.jl | 15 +++++++++++++ src/distrs/pois.jl | 8 ++++++- src/distrs/tdist.jl | 18 ++++++++++++++- test/chainrules.jl | 54 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 +- 11 files changed, 151 insertions(+), 5 deletions(-) create mode 100644 test/chainrules.jl diff --git a/Project.toml b/Project.toml index 1aa3070..6b39d6a 100644 --- a/Project.toml +++ b/Project.toml @@ -3,17 +3,21 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "0.9.6" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] +ChainRulesCore = "0.9" Rmath = "0.4, 0.5, 0.6" SpecialFunctions = "0.8, 0.9, 0.10, 1.0" julia = "1" [extras] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ForwardDiff", "Test"] +test = ["ChainRulesTestUtils", "ForwardDiff", "Random", "Test"] diff --git a/src/StatsFuns.jl b/src/StatsFuns.jl index 922316d..839f1fd 100644 --- a/src/StatsFuns.jl +++ b/src/StatsFuns.jl @@ -4,6 +4,7 @@ module StatsFuns import Base: Math.@horner, @irrational using SpecialFunctions +import ChainRulesCore export # constants diff --git a/src/distrs/beta.jl b/src/distrs/beta.jl index 3d1bfae..cc38f7f 100644 --- a/src/distrs/beta.jl +++ b/src/distrs/beta.jl @@ -17,3 +17,14 @@ betapdf(α::Real, β::Real, x::Number) = x^(α - 1) * (1 - x)^(β - 1) / beta(α # logpdf for numbers with generic types betalogpdf(α::Real, β::Real, x::Number) = (α - 1) * log(x) + (β - 1) * log1p(-x) - logbeta(α, β) + +# ChainRules adjoint +ChainRulesCore.@scalar_rule( + betalogpdf(α::Real, β::Real, x::Number), + @setup(z = digamma(α + β)), + ( + log(x) + z - digamma(α), + log1p(-x) + z - digamma(β), + (α - 1) / x + (1 - β) / (1 - x), + ), +) diff --git a/src/distrs/binom.jl b/src/distrs/binom.jl index e1237f1..039aebf 100644 --- a/src/distrs/binom.jl +++ b/src/distrs/binom.jl @@ -17,3 +17,15 @@ binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k)) # logpdf for numbers with generic types binomlogpdf(n::Real, p::Real, k::Real) = -log1p(n) - logbeta(n - k + 1, k + 1) + k * log(p) + (n - k) * log1p(-p) + +# ChainRules adjoint +ChainRulesCore.@scalar_rule( + binomlogpdf(n::Real, p::Real, k::Real), + @setup(z = digamma(n - k + 1)), + ( + digamma(n + 2) - z + log1p(-p) - 1 / (1 + n), + (k / p - n) / (1 - p), + z - digamma(k + 1) + logit(p), + ), +) + diff --git a/src/distrs/chisq.jl b/src/distrs/chisq.jl index 606923e..2df517f 100644 --- a/src/distrs/chisq.jl +++ b/src/distrs/chisq.jl @@ -21,5 +21,15 @@ end # logpdf for numbers with generic types function chisqlogpdf(k::Real, x::Number) hk = k / 2 # half k - -hk * log(oftype(hk, 2)) - loggamma(hk) + (hk - 1) * log(x) - x / 2 + -hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2 end + +# ChainRules adjoint +ChainRulesCore.@scalar_rule( + chisqlogpdf(k::Real, x::Number), + @setup(hk = k / 2), + ( + (log(x) - logtwo - digamma(hk)) / 2, + (hk - 1) / x - one(hk) / 2, + ), +) diff --git a/src/distrs/fdist.jl b/src/distrs/fdist.jl index b3041e3..1671f97 100644 --- a/src/distrs/fdist.jl +++ b/src/distrs/fdist.jl @@ -17,3 +17,20 @@ fdistpdf(ν1::Real, ν2::Real, x::Number) = sqrt((ν1 * x)^ν1 * ν2^ν2 / (ν1 # logpdf for numbers with generic types fdistlogpdf(ν1::Real, ν2::Real, x::Number) = (ν1 * log(ν1 * x) + ν2 * log(ν2) - (ν1 + ν2) * log(ν1 * x + ν2)) / 2 - log(x) - logbeta(ν1 / 2, ν2 / 2) + +# ChainRules adjoints +ChainRulesCore.@scalar_rule( + fdistlogpdf(ν1::Real, ν2::Real, x::Number), + @setup( + xν1 = x * ν1, + temp1 = xν1 + ν2, + a = (x - 1) / temp1, + νsum = ν1 + ν2, + di = digamma(νsum / 2), + ), + ( + (-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2, + (-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2, + ((ν1 - 2) / x - ν1 * νsum / temp1) / 2, + ), +) diff --git a/src/distrs/gamma.jl b/src/distrs/gamma.jl index 26dbef6..0aa7548 100644 --- a/src/distrs/gamma.jl +++ b/src/distrs/gamma.jl @@ -17,3 +17,18 @@ gammapdf(k::Real, θ::Real, x::Number) = 1 / (gamma(k) * θ^k) * x^(k - 1) * exp # logpdf for numbers with generic types gammalogpdf(k::Real, θ::Real, x::Number) = -loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ + +# ChainRules adjoints +ChainRulesCore.@scalar_rule( + gammalogpdf(k::Real, θ::Real, x::Number), + @setup( + invθ = inv(θ), + xoθ = invθ * x, + z = xoθ - k, + ), + ( + log(xoθ) - digamma(k), + invθ * z, + - (1 + z) / x, + ), +) diff --git a/src/distrs/pois.jl b/src/distrs/pois.jl index 1c1ad3b..75d989b 100644 --- a/src/distrs/pois.jl +++ b/src/distrs/pois.jl @@ -27,4 +27,10 @@ function poislogpdf(λ::Union{Float32,Float64}, x::Union{Float64,Float32,Integer -λ else -lstirling_asym(x + 1) -=# \ No newline at end of file +=# + +# ChainRules adjoints +ChainRulesCore.@scalar_rule( + poislogpdf(λ::Number, x::Number), + (x / λ - 1, log(λ) - digamma(x + 1)), +) diff --git a/src/distrs/tdist.jl b/src/distrs/tdist.jl index 322a103..24b463c 100644 --- a/src/distrs/tdist.jl +++ b/src/distrs/tdist.jl @@ -16,4 +16,20 @@ import .RFunctions: tdistpdf(ν::Real, x::Number) = gamma((ν + 1) / 2) / (sqrt(ν * pi) * gamma(ν / 2)) * (1 + x^2 / ν)^(-(ν + 1) / 2) # logpdf for numbers with generic types -tdistlogpdf(ν::Real, x::Number) = loggamma((ν + 1) / 2) - log(ν * pi) / 2 - loggamma(ν / 2) + (-(ν + 1) / 2) * log(1 + x^2 / ν) +tdistlogpdf(ν::Real, x::Number) = loggamma((ν + 1) / 2) - log(ν * pi) / 2 - loggamma(ν / 2) + (-(ν + 1) / 2) * log1p(x^2 / ν) + +# ChainRules adjoints +ChainRulesCore.@scalar_rule( + tdistlogpdf(ν::Real, x::Number), + @setup( + νp1 = ν + 1, + xsq = x^2, + invν = inv(ν), + a = xsq * invν, + b = νp1 / (ν + xsq), + ), + ( + (digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2, + - x * b, + ), +) diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 0000000..e843040 --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,54 @@ +using StatsFuns, Test +using ChainRulesTestUtils +using Random + +# move upstream? +ChainRulesTestUtils.rand_tangent(rng::Random.AbstractRNG, x::BigFloat) = big(randn(rng)) + +@testset "chainrules" begin + x, Δx, x̄ = randn(3) + y, Δy, ȳ = randn(3) + z, Δz, z̄ = randn(3) + Δu = randn() + + x̃ = exp(x) + ỹ = exp(y) + z̃ = logistic(z) + frule_test(betalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) + rrule_test(betalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) + + x̃ = exp(x) + ỹ = exp(y) + z̃ = exp(z) + frule_test(gammalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) + rrule_test(gammalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) + + x̃ = exp(x) + ỹ = exp(y) + z̃ = exp(z) + frule_test(chisqlogpdf, (x̃, Δx), (ỹ, Δy)) + rrule_test(chisqlogpdf, Δu, (x̃, x̄), (ỹ, ȳ)) + + x̃ = exp(x) + ỹ = exp(y) + z̃ = exp(z) + frule_test(fdistlogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) + rrule_test(fdistlogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) + + x̃ = exp(x) + frule_test(tdistlogpdf, (x̃, Δx), (y, Δy)) + rrule_test(tdistlogpdf, Δu, (x̃, x̄), (y, ȳ)) + + # use `BigFloat` to avoid Rmath implementation in finite differencing check + # (returns `NaN` for non-integer values) + x̃ = BigFloat(rand(1:100)) + ỹ = big(logistic(y)) + z̃ = BigFloat(rand(1:x̃)) + frule_test(binomlogpdf, (x̃, big(Δx)), (ỹ, big(Δy)), (z̃, big(Δz))) + rrule_test(binomlogpdf, big(Δu), (x̃, big(x̄)), (ỹ, big(ȳ)), (z̃, big(z̄))) + + x̃ = big(exp(x)) + ỹ = BigFloat(rand(1:100)) + frule_test(poislogpdf, (x̃, big(Δx)), (ỹ, big(Δy))) + rrule_test(poislogpdf, big(Δu), (x̃, big(x̄)), (ỹ, big(ȳ))) +end diff --git a/test/runtests.jl b/test/runtests.jl index b549ad5..ec3a1ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -tests = ["basicfuns", "rmath", "generic", "misc"] +tests = ["basicfuns", "rmath", "generic", "misc", "chainrules"] for t in tests fp = "$t.jl" From 2a660dbb3fac795e156fb2edf1f9df8bd611de78 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 9 Jan 2021 00:13:34 +0100 Subject: [PATCH 2/8] Move differentiation rules to a separate file --- src/StatsFuns.jl | 2 ++ src/chainrules.jl | 78 +++++++++++++++++++++++++++++++++++++++++++++ src/distrs/beta.jl | 11 ------- src/distrs/binom.jl | 12 ------- src/distrs/chisq.jl | 10 ------ src/distrs/fdist.jl | 17 ---------- src/distrs/gamma.jl | 15 --------- src/distrs/pois.jl | 6 ---- src/distrs/tdist.jl | 16 ---------- 9 files changed, 80 insertions(+), 87 deletions(-) create mode 100644 src/chainrules.jl diff --git a/src/StatsFuns.jl b/src/StatsFuns.jl index 839f1fd..583f879 100644 --- a/src/StatsFuns.jl +++ b/src/StatsFuns.jl @@ -259,4 +259,6 @@ include(joinpath("distrs", "pois.jl")) include(joinpath("distrs", "tdist.jl")) include(joinpath("distrs", "srdist.jl")) +include("chainrules.jl") + end # module diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 0000000..43045e2 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,78 @@ +ChainRulesCore.@scalar_rule( + betalogpdf(α::Real, β::Real, x::Number), + @setup(z = digamma(α + β)), + ( + log(x) + z - digamma(α), + log1p(-x) + z - digamma(β), + (α - 1) / x + (1 - β) / (1 - x), + ), +) + +ChainRulesCore.@scalar_rule( + binomlogpdf(n::Real, p::Real, k::Real), + @setup(z = digamma(n - k + 1)), + ( + digamma(n + 2) - z + log1p(-p) - 1 / (1 + n), + (k / p - n) / (1 - p), + z - digamma(k + 1) + logit(p), + ), +) + +ChainRulesCore.@scalar_rule( + chisqlogpdf(k::Real, x::Number), + @setup(hk = k / 2), + ( + (log(x) - logtwo - digamma(hk)) / 2, + (hk - 1) / x - one(hk) / 2, + ), +) + +ChainRulesCore.@scalar_rule( + fdistlogpdf(ν1::Real, ν2::Real, x::Number), + @setup( + xν1 = x * ν1, + temp1 = xν1 + ν2, + a = (x - 1) / temp1, + νsum = ν1 + ν2, + di = digamma(νsum / 2), + ), + ( + (-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2, + (-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2, + ((ν1 - 2) / x - ν1 * νsum / temp1) / 2, + ), +) + +ChainRulesCore.@scalar_rule( + gammalogpdf(k::Real, θ::Real, x::Number), + @setup( + invθ = inv(θ), + xoθ = invθ * x, + z = xoθ - k, + ), + ( + log(xoθ) - digamma(k), + invθ * z, + - (1 + z) / x, + ), +) + +ChainRulesCore.@scalar_rule( + poislogpdf(λ::Number, x::Number), + (x / λ - 1, log(λ) - digamma(x + 1)), +) + +ChainRulesCore.@scalar_rule( + tdistlogpdf(ν::Real, x::Number), + @setup( + νp1 = ν + 1, + xsq = x^2, + invν = inv(ν), + a = xsq * invν, + b = νp1 / (ν + xsq), + ), + ( + (digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2, + - x * b, + ), +) diff --git a/src/distrs/beta.jl b/src/distrs/beta.jl index cc38f7f..3d1bfae 100644 --- a/src/distrs/beta.jl +++ b/src/distrs/beta.jl @@ -17,14 +17,3 @@ betapdf(α::Real, β::Real, x::Number) = x^(α - 1) * (1 - x)^(β - 1) / beta(α # logpdf for numbers with generic types betalogpdf(α::Real, β::Real, x::Number) = (α - 1) * log(x) + (β - 1) * log1p(-x) - logbeta(α, β) - -# ChainRules adjoint -ChainRulesCore.@scalar_rule( - betalogpdf(α::Real, β::Real, x::Number), - @setup(z = digamma(α + β)), - ( - log(x) + z - digamma(α), - log1p(-x) + z - digamma(β), - (α - 1) / x + (1 - β) / (1 - x), - ), -) diff --git a/src/distrs/binom.jl b/src/distrs/binom.jl index 039aebf..e1237f1 100644 --- a/src/distrs/binom.jl +++ b/src/distrs/binom.jl @@ -17,15 +17,3 @@ binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k)) # logpdf for numbers with generic types binomlogpdf(n::Real, p::Real, k::Real) = -log1p(n) - logbeta(n - k + 1, k + 1) + k * log(p) + (n - k) * log1p(-p) - -# ChainRules adjoint -ChainRulesCore.@scalar_rule( - binomlogpdf(n::Real, p::Real, k::Real), - @setup(z = digamma(n - k + 1)), - ( - digamma(n + 2) - z + log1p(-p) - 1 / (1 + n), - (k / p - n) / (1 - p), - z - digamma(k + 1) + logit(p), - ), -) - diff --git a/src/distrs/chisq.jl b/src/distrs/chisq.jl index 2df517f..25a069d 100644 --- a/src/distrs/chisq.jl +++ b/src/distrs/chisq.jl @@ -23,13 +23,3 @@ function chisqlogpdf(k::Real, x::Number) hk = k / 2 # half k -hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2 end - -# ChainRules adjoint -ChainRulesCore.@scalar_rule( - chisqlogpdf(k::Real, x::Number), - @setup(hk = k / 2), - ( - (log(x) - logtwo - digamma(hk)) / 2, - (hk - 1) / x - one(hk) / 2, - ), -) diff --git a/src/distrs/fdist.jl b/src/distrs/fdist.jl index 1671f97..b3041e3 100644 --- a/src/distrs/fdist.jl +++ b/src/distrs/fdist.jl @@ -17,20 +17,3 @@ fdistpdf(ν1::Real, ν2::Real, x::Number) = sqrt((ν1 * x)^ν1 * ν2^ν2 / (ν1 # logpdf for numbers with generic types fdistlogpdf(ν1::Real, ν2::Real, x::Number) = (ν1 * log(ν1 * x) + ν2 * log(ν2) - (ν1 + ν2) * log(ν1 * x + ν2)) / 2 - log(x) - logbeta(ν1 / 2, ν2 / 2) - -# ChainRules adjoints -ChainRulesCore.@scalar_rule( - fdistlogpdf(ν1::Real, ν2::Real, x::Number), - @setup( - xν1 = x * ν1, - temp1 = xν1 + ν2, - a = (x - 1) / temp1, - νsum = ν1 + ν2, - di = digamma(νsum / 2), - ), - ( - (-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2, - (-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2, - ((ν1 - 2) / x - ν1 * νsum / temp1) / 2, - ), -) diff --git a/src/distrs/gamma.jl b/src/distrs/gamma.jl index 0aa7548..26dbef6 100644 --- a/src/distrs/gamma.jl +++ b/src/distrs/gamma.jl @@ -17,18 +17,3 @@ gammapdf(k::Real, θ::Real, x::Number) = 1 / (gamma(k) * θ^k) * x^(k - 1) * exp # logpdf for numbers with generic types gammalogpdf(k::Real, θ::Real, x::Number) = -loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ - -# ChainRules adjoints -ChainRulesCore.@scalar_rule( - gammalogpdf(k::Real, θ::Real, x::Number), - @setup( - invθ = inv(θ), - xoθ = invθ * x, - z = xoθ - k, - ), - ( - log(xoθ) - digamma(k), - invθ * z, - - (1 + z) / x, - ), -) diff --git a/src/distrs/pois.jl b/src/distrs/pois.jl index 75d989b..3caffc2 100644 --- a/src/distrs/pois.jl +++ b/src/distrs/pois.jl @@ -28,9 +28,3 @@ function poislogpdf(λ::Union{Float32,Float64}, x::Union{Float64,Float32,Integer else -lstirling_asym(x + 1) =# - -# ChainRules adjoints -ChainRulesCore.@scalar_rule( - poislogpdf(λ::Number, x::Number), - (x / λ - 1, log(λ) - digamma(x + 1)), -) diff --git a/src/distrs/tdist.jl b/src/distrs/tdist.jl index 24b463c..5dc83e9 100644 --- a/src/distrs/tdist.jl +++ b/src/distrs/tdist.jl @@ -17,19 +17,3 @@ tdistpdf(ν::Real, x::Number) = gamma((ν + 1) / 2) / (sqrt(ν * pi) * gamma(ν # logpdf for numbers with generic types tdistlogpdf(ν::Real, x::Number) = loggamma((ν + 1) / 2) - log(ν * pi) / 2 - loggamma(ν / 2) + (-(ν + 1) / 2) * log1p(x^2 / ν) - -# ChainRules adjoints -ChainRulesCore.@scalar_rule( - tdistlogpdf(ν::Real, x::Number), - @setup( - νp1 = ν + 1, - xsq = x^2, - invν = inv(ν), - a = xsq * invν, - b = νp1 / (ν + xsq), - ), - ( - (digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2, - - x * b, - ), -) From 70ca051618add8f43a166e7626f01983b1dbba3d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 02:43:56 +0200 Subject: [PATCH 3/8] Update to new test syntax --- test/chainrules.jl | 80 ++++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 42 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index e843040..717cbbc 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -6,49 +6,45 @@ using Random ChainRulesTestUtils.rand_tangent(rng::Random.AbstractRNG, x::BigFloat) = big(randn(rng)) @testset "chainrules" begin - x, Δx, x̄ = randn(3) - y, Δy, ȳ = randn(3) - z, Δz, z̄ = randn(3) - Δu = randn() - - x̃ = exp(x) - ỹ = exp(y) - z̃ = logistic(z) - frule_test(betalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) - rrule_test(betalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) - - x̃ = exp(x) - ỹ = exp(y) - z̃ = exp(z) - frule_test(gammalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) - rrule_test(gammalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) - - x̃ = exp(x) - ỹ = exp(y) - z̃ = exp(z) - frule_test(chisqlogpdf, (x̃, Δx), (ỹ, Δy)) - rrule_test(chisqlogpdf, Δu, (x̃, x̄), (ỹ, ȳ)) - - x̃ = exp(x) - ỹ = exp(y) - z̃ = exp(z) - frule_test(fdistlogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) - rrule_test(fdistlogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) - - x̃ = exp(x) - frule_test(tdistlogpdf, (x̃, Δx), (y, Δy)) - rrule_test(tdistlogpdf, Δu, (x̃, x̄), (y, ȳ)) + x = exp(randn()) + y = exp(randn()) + z = logistic(randn()) + test_frule(betalogpdf, x, y, z) + test_rrule(betalogpdf, x, y, z) + + x = exp(randn()) + y = exp(randn()) + z = exp(randn()) + test_frule(gammalogpdf, x, y, z) + test_rrule(gammalogpdf, x, y, z) + + x = exp(randn()) + y = exp(randn()) + test_frule(chisqlogpdf, x, y) + test_rrule(chisqlogpdf, x, y) + + x = exp(randn()) + y = exp(randn()) + z = exp(randn()) + test_frule(fdistlogpdf, x, y, z) + test_rrule(fdistlogpdf, x, y, z) + + x = exp(randn()) + y = randn() + test_frule(tdistlogpdf, x, y) + test_rrule(tdistlogpdf, x, y) # use `BigFloat` to avoid Rmath implementation in finite differencing check # (returns `NaN` for non-integer values) - x̃ = BigFloat(rand(1:100)) - ỹ = big(logistic(y)) - z̃ = BigFloat(rand(1:x̃)) - frule_test(binomlogpdf, (x̃, big(Δx)), (ỹ, big(Δy)), (z̃, big(Δz))) - rrule_test(binomlogpdf, big(Δu), (x̃, big(x̄)), (ỹ, big(ȳ)), (z̃, big(z̄))) - - x̃ = big(exp(x)) - ỹ = BigFloat(rand(1:100)) - frule_test(poislogpdf, (x̃, big(Δx)), (ỹ, big(Δy))) - rrule_test(poislogpdf, big(Δu), (x̃, big(x̄)), (ỹ, big(ȳ))) + n = rand(1:100) + x = BigFloat(n) + y = big(logistic(randn())) + z = BigFloat(rand(1:n)) + test_frule(binomlogpdf, x, y, z) + test_rrule(binomlogpdf, x, y, z) + + x = big(exp(randn())) + y = BigFloat(rand(1:100)) + test_frule(poislogpdf, x, y) + test_rrule(poislogpdf, x, y) end From f1e92f03b0ab0247b6ce3faf798c27aca37d0dff Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 13:12:39 +0200 Subject: [PATCH 4/8] `rand_tangent` is fixed upstream --- test/chainrules.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 717cbbc..5163102 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -2,9 +2,6 @@ using StatsFuns, Test using ChainRulesTestUtils using Random -# move upstream? -ChainRulesTestUtils.rand_tangent(rng::Random.AbstractRNG, x::BigFloat) = big(randn(rng)) - @testset "chainrules" begin x = exp(randn()) y = exp(randn()) From 50af675ae4855fba30f6c16dbd553da5460f7b0f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 2 Jun 2021 12:20:12 +0200 Subject: [PATCH 5/8] Add support for ChainRulesCore 0.10 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9a6c358..10059af 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,7 @@ Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] -ChainRulesCore = "0.9" +ChainRulesCore = "0.9, 0.10" LogExpFunctions = "0.2.1" Rmath = "0.4, 0.5, 0.6, 0.7" SpecialFunctions = "0.8, 0.9, 0.10, 1.0" From f154f25e2ee1c880239c93192e51c6b5528dff57 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 13 Aug 2021 16:16:13 +0200 Subject: [PATCH 6/8] Fix definition of chainrule for `poislogpdf` --- src/chainrules.jl | 2 +- test/chainrules.jl | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 43045e2..f405c21 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -59,7 +59,7 @@ ChainRulesCore.@scalar_rule( ChainRulesCore.@scalar_rule( poislogpdf(λ::Number, x::Number), - (x / λ - 1, log(λ) - digamma(x + 1)), + ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)), ) ChainRulesCore.@scalar_rule( diff --git a/test/chainrules.jl b/test/chainrules.jl index 5163102..e469ebf 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,4 +1,5 @@ using StatsFuns, Test +using ChainRulesCore using ChainRulesTestUtils using Random @@ -44,4 +45,12 @@ using Random y = BigFloat(rand(1:100)) test_frule(poislogpdf, x, y) test_rrule(poislogpdf, x, y) + + # test special case λ = 0 + _, pb = rrule(StatsFuns.poislogpdf, 0.0, 0.0) + _, x̄1, _ = pb(1) + @test x̄1 == -1 + _, pb = rrule(StatsFuns.poislogpdf, 0.0, 1.0) + _, x̄1, _ = pb(1) + @test x̄1 == Inf end From ce79c00ab9f30bdb2318d118d16648a50173d70b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 13 Aug 2021 16:20:15 +0200 Subject: [PATCH 7/8] Use ChainRulesCore 1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ad7d37d..5c460f9 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] -ChainRulesCore = "0.9, 0.10" +ChainRulesCore = "0.9, 0.10, 1" IrrationalConstants = "0.1" LogExpFunctions = "0.3" Reexport = "1" From 81ee96064f1c47671847adf2045ad5ad70e8db07 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 25 Aug 2021 15:23:03 +0200 Subject: [PATCH 8/8] Only support ChainRulesCore 1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5c460f9..91b946e 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] -ChainRulesCore = "0.9, 0.10, 1" +ChainRulesCore = "1" IrrationalConstants = "0.1" LogExpFunctions = "0.3" Reexport = "1"