diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py index a8e6f2e317..21b3a71aee 100644 --- a/pymc3/distributions/mixture.py +++ b/pymc3/distributions/mixture.py @@ -6,6 +6,8 @@ from .distribution import Discrete, Distribution, draw_values, generate_samples from .continuous import get_tau_sd, Normal +__all__ = ['Mixture', 'NormalMixture'] + def all_discrete(comp_dists): """ @@ -167,3 +169,26 @@ def __init__(self, w, mu, *args, **kwargs): super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd), *args, **kwargs) + + +class Zero(Discrete): + def __init__(self, *args, **kwargs): + super(Zero, self).__init__(*args, **kwargs) + + def logp(self, value): + return tt.switch(tt.eq(value, 0), 0., -np.inf) + + def random(self, point=None, size=None, repeat=None): + def _random(dtype=self.dtype, size=None): + return np.full(size, fill_value=0, dtype=dtype) + + return generate_samples(_random, dist_shape=self.shape, + size=size).astype(self.dtype) + + +class ZeroInflatedPoisson(Mixture): + def __init__(self, theta, psi, *args, **kwargs): + w = tt.stack([psi, 1 - psi]) + comp_dists = [Zero.dist(), Poisson.dist(theta)] + + super(ZeroInflatedPoisson, self).__init__(w, comp_dists, *args, **kwargs)