Skip to content

Commit

Permalink
Use _logp and _logcdf dispatcher in ZeroInflated* methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Jun 5, 2021
1 parent d91d649 commit 8ce7f76
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
normal_lcdf,
)
from pymc3.distributions.distribution import Discrete
from pymc3.distributions.logp import _logcdf, _logp
from pymc3.math import log1mexp, logaddexp, logsumexp, sigmoid

__all__ = [
Expand Down Expand Up @@ -1306,7 +1307,7 @@ def logp(value, psi, theta):

logp_val = at.switch(
at.gt(value, 0),
at.log(psi) + Poisson.logp(value, theta),
at.log(psi) + _logp(poisson, value, {}, theta),
logaddexp(at.log1p(-psi), at.log(psi) - theta),
)

Expand Down Expand Up @@ -1335,7 +1336,7 @@ def logcdf(value, psi, theta):
"""

return bound(
logaddexp(at.log1p(-psi), at.log(psi) + Poisson.logcdf(value, theta)),
logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(poisson, value, {}, theta)),
0 <= value,
0 <= psi,
psi <= 1,
Expand Down Expand Up @@ -1436,7 +1437,7 @@ def logp(value, psi, n, p):

logp_val = at.switch(
at.gt(value, 0),
at.log(psi) + Binomial.logp(value, n, p),
at.log(psi) + _logp(binomial, value, {}, n, p),
logaddexp(at.log1p(-psi), at.log(psi) + n * at.log1p(-p)),
)

Expand Down Expand Up @@ -1472,7 +1473,7 @@ def logcdf(value, psi, n, p):
)

return bound(
logaddexp(at.log1p(-psi), at.log(psi) + Binomial.logcdf(value, n, p)),
logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(binomial, value, {}, n, p)),
0 <= value,
value <= n,
0 <= psi,
Expand Down Expand Up @@ -1594,7 +1595,7 @@ def logp(value, psi, n, p):
return bound(
at.switch(
at.gt(value, 0),
at.log(psi) + NegativeBinomial.logp(value, n, p),
at.log(psi) + _logp(nbinom, value, {}, n, p),
logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)),
),
0 <= value,
Expand Down Expand Up @@ -1626,7 +1627,7 @@ def logcdf(value, psi, n, p):
)

return bound(
logaddexp(at.log1p(-psi), at.log(psi) + NegativeBinomial.logcdf(value, n, p)),
logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(nbinom, value, {}, n, p)),
0 <= value,
0 <= psi,
psi <= 1,
Expand Down

0 comments on commit 8ce7f76

Please sign in to comment.