Skip to content

Commit

Permalink
Refactor ZeroInflatedNegativeBinomial
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 11, 2021
1 parent 1fcbbb6 commit f3286d8
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 63 deletions.
97 changes: 41 additions & 56 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def NegBinom(a, m, x):

@classmethod
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
n, p = cls.get_n_p(mu, alpha, p, n)
n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
n = at.as_tensor_variable(floatX(n))
p = at.as_tensor_variable(floatX(p))
return super().dist([n, p], *args, **kwargs)
Expand Down Expand Up @@ -1481,6 +1481,21 @@ def logcdf(value, psi, n, p):
)


class ZeroInflatedNegBinomialRV(RandomVariable):
name = "zero_inflated_neg_binomial"
ndim_supp = 0
ndims_params = [0, 0, 0]
dtype = "int64"
_print_name = ("ZeroInflatedNegBinom", "\\operatorname{ZeroInflatedNegBinom}")

@classmethod
def rng_fn(cls, rng, psi, n, p, size):
return rng.negative_binomial(n=n, p=p, size=size) * (rng.random(size=size) < psi)


zero_inflated_neg_binomial = ZeroInflatedNegBinomialRV()


class ZeroInflatedNegativeBinomial(Discrete):
R"""
Zero-Inflated Negative binomial log-likelihood.
Expand Down Expand Up @@ -1550,50 +1565,17 @@ def ZeroInfNegBinom(a, m, psi, x):
"""

def __init__(self, psi, mu, alpha, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mu = mu = at.as_tensor_variable(floatX(mu))
self.alpha = alpha = at.as_tensor_variable(floatX(alpha))
self.psi = psi = at.as_tensor_variable(floatX(psi))
self.nb = NegativeBinomial.dist(mu, alpha)
self.mode = self.nb.mode
rv_op = zero_inflated_neg_binomial

def random(self, point=None, size=None):
r"""
Draw random values from ZeroInflatedNegativeBinomial distribution.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
Returns
-------
array
"""
# mu, alpha, psi = draw_values([self.mu, self.alpha, self.psi], point=point, size=size)
# g = generate_samples(self._random, mu=mu, alpha=alpha, dist_shape=self.shape, size=size)
# g[g == 0] = np.finfo(float).eps # Just in case
# g, psi = broadcast_distribution_samples([g, psi], size=size)
# return stats.poisson.rvs(g) * (np.random.random(g.shape) < psi)

def _random(self, mu, alpha, size):
r"""Wrapper around stats.gamma.rvs that converts NegativeBinomial's
parametrization to scipy.gamma. All parameter arrays should have
been broadcasted properly by generate_samples at this point and size is
the scipy.rvs representation.
"""
return stats.gamma.rvs(
a=alpha,
scale=mu / alpha,
size=size,
)
@classmethod
def dist(cls, psi, mu, alpha, *args, **kwargs):
psi = at.as_tensor_variable(floatX(psi))
n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha)
n = at.as_tensor_variable(floatX(n))
p = at.as_tensor_variable(floatX(p))
return super().dist([psi, n, p], *args, **kwargs)

def logp(self, value):
def logp(value, psi, n, p):
r"""
Calculate log-probability of ZeroInflatedNegativeBinomial distribution at specified value.
Expand All @@ -1607,20 +1589,22 @@ def logp(self, value):
-------
TensorVariable
"""
alpha = self.alpha
mu = self.mu
psi = self.psi

logp_other = at.log(psi) + self.nb.logp(value)
logp_0 = logaddexp(
at.log1p(-psi), at.log(psi) + alpha * (at.log(alpha) - at.log(alpha + mu))
return bound(
at.switch(
at.gt(value, 0),
at.log(psi) + NegativeBinomial.logp(value, n, p),
logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)),
),
0 <= value,
0 <= psi,
psi <= 1,
0 < n,
0 <= p,
p <= 1,
)

logp_val = at.switch(at.gt(value, 0), logp_other, logp_0)

return bound(logp_val, 0 <= value, 0 <= psi, psi <= 1, mu > 0, alpha > 0)

def logcdf(self, value):
def logcdf(value, psi, n, p):
"""
Compute the log of the cumulative distribution function for ZeroInflatedNegativeBinomial distribution
at the specified value.
Expand All @@ -1639,13 +1623,14 @@ def logcdf(self, value):
raise TypeError(
f"ZeroInflatedNegativeBinomial.logcdf expects a scalar value but received a {np.ndim(value)}-dimensional object."
)
psi = self.psi

return bound(
logaddexp(at.log1p(-psi), at.log(psi) + self.nb.logcdf(value)),
logaddexp(at.log1p(-psi), at.log(psi) + NegativeBinomial.logcdf(value, n, p)),
0 <= value,
0 <= psi,
psi <= 1,
0 < p,
p <= 1,
)


Expand Down
34 changes: 29 additions & 5 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,8 +1660,7 @@ def logcdf_fn(value, psi, theta):
{"theta": Rplus, "psi": Unit},
)

# Too lazy to propagate decimal parameter through the whole chain of deps
@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.xfail(reason="Test not refactored yet")
@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Fails on float32 due to inf issues",
Expand All @@ -1673,12 +1672,37 @@ def test_zeroinflatednegativebinomial_distribution(self):
{"mu": Rplusbig, "alpha": Rplusbig, "psi": Unit},
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_zeroinflatednegativebinomial_logcdf(self):
def test_zeroinflatednegativebinomial(self):
def logp_fn(value, psi, mu, alpha):
n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha)
if value == 0:
return np.log((1 - psi) * sp.nbinom.pmf(0, n, p))
else:
return np.log(psi * sp.nbinom.pmf(value, n, p))

def logcdf_fn(value, psi, mu, alpha):
n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha)
return np.log((1 - psi) + psi * sp.nbinom.cdf(value, n, p))

self.check_logp(
ZeroInflatedNegativeBinomial,
Nat,
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
logp_fn,
)

self.check_logcdf(
ZeroInflatedNegativeBinomial,
Nat,
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
logcdf_fn,
n_samples=10,
)

self.check_selfconsistency_discrete_logcdf(
ZeroInflatedNegativeBinomial,
Nat,
{"mu": Rplusbig, "alpha": Rplusbig, "psi": Unit},
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
n_samples=10,
)

Expand Down
39 changes: 37 additions & 2 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ def seeded_zero_inflated_poisson_rng_fn(self):


class TestZeroInflatedBinomial(BaseTestDistribution):
def zero_inflated_poisson_rng_fn(self, size, psi, n, p, binomial_rng_fct, random_rng_fct):
def zero_inflated_binomial_rng_fn(self, size, psi, n, p, binomial_rng_fct, random_rng_fct):
return binomial_rng_fct(n, p, size=size) * (random_rng_fct(size=size) < psi)

def seeded_zero_inflated_binomial_rng_fn(self):
Expand All @@ -967,7 +967,7 @@ def seeded_zero_inflated_binomial_rng_fn(self):
)

return functools.partial(
self.zero_inflated_poisson_rng_fn,
self.zero_inflated_binomial_rng_fn,
binomial_rng_fct=binomial_rng_fct,
random_rng_fct=random_rng_fct,
)
Expand All @@ -984,6 +984,41 @@ def seeded_zero_inflated_binomial_rng_fn(self):
]


class TestZeroInflatedNegativeBinomial(BaseTestDistribution):
def zero_inflated_negbinomial_rng_fn(
self, size, psi, n, p, negbinomial_rng_fct, random_rng_fct
):
return negbinomial_rng_fct(n, p, size=size) * (random_rng_fct(size=size) < psi)

def seeded_zero_inflated_negbinomial_rng_fn(self):
negbinomial_rng_fct = functools.partial(
getattr(np.random.RandomState, "negative_binomial"), self.get_random_state()
)

random_rng_fct = functools.partial(
getattr(np.random.RandomState, "random"), self.get_random_state()
)

return functools.partial(
self.zero_inflated_negbinomial_rng_fn,
negbinomial_rng_fct=negbinomial_rng_fct,
random_rng_fct=random_rng_fct,
)

n, p = pm.NegativeBinomial.get_n_p(mu=3, alpha=5)

pymc_dist = pm.ZeroInflatedNegativeBinomial
pymc_dist_params = {"psi": 0.9, "mu": 3, "alpha": 5}
expected_rv_op_params = {"psi": 0.9, "n": n, "p": p}
reference_dist_params = {"psi": 0.9, "n": n, "p": p}
reference_dist = seeded_zero_inflated_negbinomial_rng_fn
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestOrderedLogistic(BaseTestDistribution):
pymc_dist = pm.OrderedLogistic
pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])}
Expand Down

0 comments on commit f3286d8

Please sign in to comment.