Skip to content

Commit

Permalink
Modify NegativeBinomial threshold for Poisson approximation based on …
Browse files Browse the repository at this point in the history
…float accuracy, and add tests.
  • Loading branch information
ricardoV94 committed Mar 19, 2021
1 parent a4e79b1 commit 2f6c0a9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
10 changes: 9 additions & 1 deletion pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.
import warnings

import aesara
import aesara.tensor as aet
import numpy as np

from aesara.tensor.random.basic import bernoulli, binomial, categorical, nbinom, poisson
from scipy import stats

from pymc3.aesaraf import floatX, intX, take_along_axis
from pymc3.distributions import logpt
from pymc3.distributions.dist_math import (
betaln,
binomln,
Expand Down Expand Up @@ -782,7 +784,13 @@ def logp(value, n, p):
)

# Return Poisson when alpha gets very large.
return aet.switch(aet.gt(alpha, 1e10), Poisson.dist(mu).logp(value), negbinom)
threshold = 1e10 if aesara.config.floatX == "float64" else 1e7

return aet.switch(
aet.gt(alpha, threshold),
logpt(Poisson.dist(mu), value),
negbinom,
)

def logcdf(value, n, p):
"""
Expand Down
14 changes: 12 additions & 2 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,11 +1140,13 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
)

@pytest.mark.xfail(reason="Logp method not refactored yet")
def test_negative_binomial(self):
def scipy_mu_alpha_logpmf(value, mu, alpha):
return sp.nbinom.logpmf(value, alpha, 1 - mu / (mu + alpha))

def scipy_poisson_approximation(value, mu, alpha):
return sp.poisson.logpmf(value, mu)

def scipy_mu_alpha_logcdf(value, mu, alpha):
return sp.nbinom.logcdf(value, alpha, 1 - mu / (mu + alpha))

Expand All @@ -1154,6 +1156,15 @@ def scipy_mu_alpha_logcdf(value, mu, alpha):
{"mu": Rplus, "alpha": Rplus},
scipy_mu_alpha_logpmf,
)
# Test values around the point where the logp of the NegativeBinomial
# is approximated with a Poisson logp
self.check_logp(
NegativeBinomial,
Nat,
{"mu": Rplus, "alpha": Domain([0, 1.5e6, 1.5e7, 1.5e9, 1.5e10, inf])},
scipy_poisson_approximation,
decimal=select_by_precision(float64=2, float32=1),
)
self.check_logp(
NegativeBinomial,
Nat,
Expand Down Expand Up @@ -1181,7 +1192,6 @@ def scipy_mu_alpha_logcdf(value, mu, alpha):
n_samples=10,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.parametrize(
"mu, p, alpha, n, expected",
[
Expand Down

0 comments on commit 2f6c0a9

Please sign in to comment.