Skip to content

Commit

Permalink
AD workaround for nbinomlogpdf (#664)
Browse files Browse the repository at this point in the history
* ad workaround for nbinomlogpdf

* fix nbinomlogpdf derivatives

* add negative binomial logpdf AD tests

* add a test for the full gradient of nbinomlogpdf
  • Loading branch information
mohamed82008 authored and yebai committed Feb 3, 2019
1 parent bf8033a commit 21c164a
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ Tracker.@grad function binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int)
Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing)
end

import StatsFuns: nbinomlogpdf
# Note the definition of NegativeBinomial in Julia is not the same as Wikipedia's.
# Check the docstring of NegativeBinomial, r is the number of successes and
# k is the number of failures
_nbinomlogpdf_grad_1(r, p, k) = sum(1 / (k + r - i) for i in 1:k) + log(p)
_nbinomlogpdf_grad_2(r, p, k) = -k / (1 - p) + r / p

nbinomlogpdf(n::Tracker.TrackedReal, p::Tracker.TrackedReal, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
nbinomlogpdf(n::Real, p::Tracker.TrackedReal, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
nbinomlogpdf(n::Tracker.TrackedReal, p::Real, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
Tracker.@grad function nbinomlogpdf(r::Tracker.TrackedReal, p::Tracker.TrackedReal, k::Int)
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
Δ->* _nbinomlogpdf_grad_1(r, p, k), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing)
end
Tracker.@grad function nbinomlogpdf(r::Real, p::Tracker.TrackedReal, k::Int)
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
Δ->(Tracker._zero(r), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing)
end
Tracker.@grad function nbinomlogpdf(r::Tracker.TrackedReal, p::Real, k::Int)
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
Δ->* _nbinomlogpdf_grad_1(r, p, k), Tracker._zero(p), nothing)
end

import StatsFuns: poislogpdf
poislogpdf(v::Tracker.TrackedReal, x::Int) = Tracker.track(poislogpdf, v, x)
Expand All @@ -195,6 +217,29 @@ function binomlogpdf(n::Int, p::ForwardDiff.Dual{T}, x::Int) where {T}
return FD(binomlogpdf(n, val, x), Δ * (x / val - (n - x) / (1 - val)))
end

function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::ForwardDiff.Dual{T}, k::Int) where {T}
FD = ForwardDiff.Dual{T}
val_p = ForwardDiff.value(p)
val_r = ForwardDiff.value(r)

Δ_r = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(val_r, val_p, k)
Δ_p = ForwardDiff.partials(p) * _nbinomlogpdf_grad_2(val_r, val_p, k)
Δ = Δ_p + Δ_r
return FD(nbinomlogpdf(val_r, val_p, k), Δ)
end
function nbinomlogpdf(r::Real, p::ForwardDiff.Dual{T}, k::Int) where {T}
FD = ForwardDiff.Dual{T}
val_p = ForwardDiff.value(p)
Δ_p = ForwardDiff.partials(p) * _nbinomlogpdf_grad_2(r, val_p, k)
return FD(nbinomlogpdf(r, val_p, k), Δ_p)
end
function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::Real, k::Int) where {T}
FD = ForwardDiff.Dual{T}
val_r = ForwardDiff.value(r)
Δ_r = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(val_r, p, k)
return FD(nbinomlogpdf(val_r, p, k), Δ_r)
end

function poislogpdf(v::ForwardDiff.Dual{T}, x::Int) where {T}
FD = ForwardDiff.Dual{T}
val = ForwardDiff.value(v)
Expand Down
70 changes: 70 additions & 0 deletions test/ad.jl/AD_compatibility_with_distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,73 @@ let
atol=1e-8,
)
end

let
foo = p->Turing.nbinomlogpdf(5, p, 1)
@test isapprox(
Tracker.gradient(foo, 0.5)[1],
central_fdm(5, 1)(foo, 0.5);
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(foo, 0.5)[1],
ForwardDiff.derivative(foo, 0.5);
rtol=1e-8,
atol=1e-8,
)

bar = p->logpdf(NegativeBinomial(5, p), 3)
@test isapprox(
Tracker.gradient(bar, 0.5)[1],
central_fdm(5, 1)(bar, 0.5);
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(bar, 0.5)[1],
ForwardDiff.derivative(bar, 0.5);
rtol=1e-8,
atol=1e-8,
)
end

let
foo = r->Turing.nbinomlogpdf(r, 0.5, 1)
@test isapprox(
Tracker.gradient(foo, 3.5)[1],
central_fdm(5, 1)(foo, 3.5);
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(foo, 3.5)[1],
ForwardDiff.derivative(foo, 3.5);
rtol=1e-8,
atol=1e-8,
)

bar = r->logpdf(NegativeBinomial(r, 0.5), 3)
@test isapprox(
Tracker.gradient(bar, 3.5)[1],
central_fdm(5, 1)(bar, 3.5);
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(bar, 3.5)[1],
ForwardDiff.derivative(bar, 3.5);
rtol=1e-8,
atol=1e-8,
)
end

let
foo = x -> Turing.nbinomlogpdf(x[1], x[2], 1)
@test isapprox(
Tracker.gradient(foo, [3.5, 0.5])[1],
ForwardDiff.gradient(foo, [3.5, 0.5]);
rtol=1e-8,
atol=1e-8,
)
end

0 comments on commit 21c164a

Please sign in to comment.