diff --git a/pymc3/distributions/discrete.py b/pymc3/distributions/discrete.py index ca9f091ad13..67cc76c106c 100644 --- a/pymc3/distributions/discrete.py +++ b/pymc3/distributions/discrete.py @@ -13,6 +13,7 @@ # limitations under the License. import warnings +import aesara import aesara.tensor as aet import numpy as np @@ -20,6 +21,7 @@ from scipy import stats from pymc3.aesaraf import floatX, intX, take_along_axis +from pymc3.distributions import logpt from pymc3.distributions.dist_math import ( betaln, binomln, @@ -782,7 +784,13 @@ def logp(value, n, p): ) # Return Poisson when alpha gets very large. - return aet.switch(aet.gt(alpha, 1e10), Poisson.dist(mu).logp(value), negbinom) + threshold = 1e10 if aesara.config.floatX == "float64" else 1e7 + + return aet.switch( + aet.gt(alpha, threshold), + logpt(Poisson.dist(mu), value), + negbinom, + ) def logcdf(value, n, p): """ diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 67a62c1fffc..652751f8a8f 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1140,11 +1140,13 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n): {"N": NatSmall, "k": NatSmall, "n": NatSmall}, ) - @pytest.mark.xfail(reason="Logp method not refactored yet") def test_negative_binomial(self): def scipy_mu_alpha_logpmf(value, mu, alpha): return sp.nbinom.logpmf(value, alpha, 1 - mu / (mu + alpha)) + def scipy_poisson_approximation(value, mu, alpha): + return sp.poisson.logpmf(value, mu) + def scipy_mu_alpha_logcdf(value, mu, alpha): return sp.nbinom.logcdf(value, alpha, 1 - mu / (mu + alpha)) @@ -1154,6 +1156,15 @@ def scipy_mu_alpha_logcdf(value, mu, alpha): {"mu": Rplus, "alpha": Rplus}, scipy_mu_alpha_logpmf, ) + # Test values around the point where the logp of the NegativeBinomial + # is approximated with a Poisson logp + self.check_logp( + NegativeBinomial, + Nat, + {"mu": Rplus, "alpha": Domain([0, 1.5e6, 1.5e7, 1.5e9, 1.5e10, inf])}, + scipy_poisson_approximation, + decimal=select_by_precision(float64=2, float32=1), + ) self.check_logp( NegativeBinomial, Nat, @@ -1181,7 +1192,6 @@ def scipy_mu_alpha_logcdf(value, mu, alpha): n_samples=10, ) - @pytest.mark.xfail(reason="Distribution not refactored yet") @pytest.mark.parametrize( "mu, p, alpha, n, expected", [