Skip to content

Shape issue when prior predictive sampling from ZeroInflatedPoisson #3310

Closed
@twiecki

Description

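@twiecki

Drawing prior predictive samples from a `ZeroInflatedPoisson` with `shape=20`, whose `psi` and `theta` are themselves random variables, raises a `ValueError`: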
with pm.Model() as model:
    θ = pm.Beta('θ', alpha=1, beta=1)
    ψ = pm.HalfNormal('ψ', sd=1)
    s = pm.ZeroInflatedPoisson('suppliers', psi=ψ, theta=θ, 
                               shape=20)
    gen_data = pm.sample_prior_predictive(samples=5000)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-24-aef3a9b7b63a> in <module>()
      4     s = pm.ZeroInflatedPoisson('suppliers', psi=ψ, theta=θ, 
      5                                shape=20)
----> 6     gen_data = pm.sample_prior_predictive(samples=5000)

~/working/projects/pymc/pymc3/sampling.py in sample_prior_predictive(samples, model, vars, random_seed)
   1323     names = get_default_varnames(model.named_vars, include_transformed=False)
   1324     # draw_values fails with auto-transformed variables. transform them later!
-> 1325     values = draw_values([model[name] for name in names], size=samples)
   1326 
   1327     data = {k: v for k, v in zip(names, values)}

~/working/projects/pymc/pymc3/distributions/distribution.py in draw_values(params, point, size)
    374                     # the stack of nodes to try to draw from. We exclude the
    375                     # nodes in the `params` list.
--> 376                     stack.extend([node for node in named_nodes_parents[next_]
    377                                   if node is not None and
    378                                   node.name not in drawn and

~/working/projects/pymc/pymc3/distributions/distribution.py in _draw_value(param, point, givens, size)
    468                 # shape inspection for ObservedRV
    469                 dist_tmp = param.distribution
--> 470                 try:
    471                     distshape = param.observations.shape.eval()
    472                 except AttributeError:

~/working/projects/pymc/pymc3/model.py in __call__(self, *args, **kwargs)
     41 
     42     def __call__(self, *args, **kwargs):
---> 43         return getattr(self.obj, self.method_name)(*args, **kwargs)
     44 
     45 

~/working/projects/pymc/pymc3/distributions/discrete.py in random(self, point, size)
    848                              dist_shape=self.shape,
    849                              size=size)
--> 850         return g * (np.random.random(np.squeeze(g.shape)) < psi)
    851 
    852     def logp(self, value):

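ValueError: operands could not be broadcast together with shapes (5000,20) (5000,) 

The failure is in `ZeroInflatedPoisson.random`, at the expression `np.random.random(np.squeeze(g.shape)) < psi`. The Poisson draws `g` have shape `(5000, 20)`: 5000 prior samples of a 20-element variable. The drawn `psi` has shape `(5000,)`: one value per prior sample. NumPy broadcasting aligns the trailing axes, which compares 20 against 5000, so the comparison fails. To broadcast against `g`, `psi` would need a trailing axis, i.e. shape `(5000, 1)`.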

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions