Skip to content

Commit

Permalink
Add ChainRules adjoints (#106)
Browse files Browse the repository at this point in the history
* Add ChainRules adjoints

* Move differentiation rules to a separate file

* Update to new test syntax

* `rand_tangent` is fixed upstream

* Add support for ChainRulesCore 0.10

* Fix definition of chainrule for `poislogpdf`

* Use ChainRulesCore 1

* Only support ChainRulesCore 1
  • Loading branch information
devmotion authored Aug 30, 2021
1 parent 36e48bb commit 3aced77
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 5 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.9"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
ChainRulesCore = "1"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
Reexport = "1"
Expand All @@ -18,8 +20,10 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ForwardDiff", "Test"]
test = ["ChainRulesTestUtils", "ForwardDiff", "Random", "Test"]
3 changes: 3 additions & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module StatsFuns
using Base: Math.@horner
using Reexport
using SpecialFunctions
import ChainRulesCore

# reexports
@reexport using IrrationalConstants:
Expand Down Expand Up @@ -257,4 +258,6 @@ include(joinpath("distrs", "pois.jl"))
include(joinpath("distrs", "tdist.jl"))
include(joinpath("distrs", "srdist.jl"))

include("chainrules.jl")

end # module
78 changes: 78 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
ChainRulesCore.@scalar_rule(
betalogpdf::Real, β::Real, x::Number),
@setup(z = digamma+ β)),
(
log(x) + z - digamma(α),
log1p(-x) + z - digamma(β),
- 1) / x + (1 - β) / (1 - x),
),
)

ChainRulesCore.@scalar_rule(
binomlogpdf(n::Real, p::Real, k::Real),
@setup(z = digamma(n - k + 1)),
(
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
(k / p - n) / (1 - p),
z - digamma(k + 1) + logit(p),
),
)

ChainRulesCore.@scalar_rule(
chisqlogpdf(k::Real, x::Number),
@setup(hk = k / 2),
(
(log(x) - logtwo - digamma(hk)) / 2,
(hk - 1) / x - one(hk) / 2,
),
)

ChainRulesCore.@scalar_rule(
fdistlogpdf(ν1::Real, ν2::Real, x::Number),
@setup(
xν1 = x * ν1,
temp1 = xν1 + ν2,
a = (x - 1) / temp1,
νsum = ν1 + ν2,
di = digamma(νsum / 2),
),
(
(-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2,
(-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2,
((ν1 - 2) / x - ν1 * νsum / temp1) / 2,
),
)

ChainRulesCore.@scalar_rule(
gammalogpdf(k::Real, θ::Real, x::Number),
@setup(
invθ = inv(θ),
xoθ = invθ * x,
z = xoθ - k,
),
(
log(xoθ) - digamma(k),
invθ * z,
- (1 + z) / x,
),
)

ChainRulesCore.@scalar_rule(
poislogpdf::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
)

ChainRulesCore.@scalar_rule(
tdistlogpdf::Real, x::Number),
@setup(
νp1 = ν + 1,
xsq = x^2,
invν = inv(ν),
a = xsq * invν,
b = νp1 /+ xsq),
),
(
(digamma(νp1 / 2) - digamma/ 2) + a * b - log1p(a) - invν) / 2,
- x * b,
),
)
2 changes: 1 addition & 1 deletion src/distrs/chisq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ end
# logpdf for numbers with generic types
function chisqlogpdf(k::Real, x::Number)
hk = k / 2 # half k
-hk * log(oftype(hk, 2)) - loggamma(hk) + (hk - 1) * log(x) - x / 2
-hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2
end
2 changes: 1 addition & 1 deletion src/distrs/pois.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ function poislogpdf(λ::Union{Float32,Float64}, x::Union{Float64,Float32,Integer
else
-lstirling_asym(x + 1)
=#
=#
2 changes: 1 addition & 1 deletion src/distrs/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ import .RFunctions:
tdistpdf::Real, x::Number) = gamma((ν + 1) / 2) / (sqrt* pi) * gamma/ 2)) * (1 + x^2 / ν)^(-+ 1) / 2)

# logpdf for numbers with generic types
tdistlogpdf::Real, x::Number) = loggamma((ν + 1) / 2) - log* pi) / 2 - loggamma/ 2) + (-+ 1) / 2) * log(1 + x^2 / ν)
tdistlogpdf::Real, x::Number) = loggamma((ν + 1) / 2) - log* pi) / 2 - loggamma/ 2) + (-+ 1) / 2) * log1p(x^2 / ν)
56 changes: 56 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using StatsFuns, Test
using ChainRulesCore
using ChainRulesTestUtils
using Random

@testset "chainrules" begin
x = exp(randn())
y = exp(randn())
z = logistic(randn())
test_frule(betalogpdf, x, y, z)
test_rrule(betalogpdf, x, y, z)

x = exp(randn())
y = exp(randn())
z = exp(randn())
test_frule(gammalogpdf, x, y, z)
test_rrule(gammalogpdf, x, y, z)

x = exp(randn())
y = exp(randn())
test_frule(chisqlogpdf, x, y)
test_rrule(chisqlogpdf, x, y)

x = exp(randn())
y = exp(randn())
z = exp(randn())
test_frule(fdistlogpdf, x, y, z)
test_rrule(fdistlogpdf, x, y, z)

x = exp(randn())
y = randn()
test_frule(tdistlogpdf, x, y)
test_rrule(tdistlogpdf, x, y)

# use `BigFloat` to avoid Rmath implementation in finite differencing check
# (returns `NaN` for non-integer values)
n = rand(1:100)
x = BigFloat(n)
y = big(logistic(randn()))
z = BigFloat(rand(1:n))
test_frule(binomlogpdf, x, y, z)
test_rrule(binomlogpdf, x, y, z)

x = big(exp(randn()))
y = BigFloat(rand(1:100))
test_frule(poislogpdf, x, y)
test_rrule(poislogpdf, x, y)

# test special case λ = 0
_, pb = rrule(StatsFuns.poislogpdf, 0.0, 0.0)
_, x̄1, _ = pb(1)
@test x̄1 == -1
_, pb = rrule(StatsFuns.poislogpdf, 0.0, 1.0)
_, x̄1, _ = pb(1)
@test x̄1 == Inf
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tests = ["rmath", "generic", "misc"]
tests = ["rmath", "generic", "misc", "chainrules"]

for t in tests
fp = "$t.jl"
Expand Down

0 comments on commit 3aced77

Please sign in to comment.