diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7bd93a52c..91c96bef3 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -311,6 +311,7 @@ class ExpandedDistribution(Distribution): def __init__(self, base_dist, batch_shape=()): if isinstance(base_dist, ExpandedDistribution): + batch_shape = self._broadcast_shape(base_dist.batch_shape, batch_shape) base_dist = base_dist.base_dist self.base_dist = base_dist super().__init__(base_dist.batch_shape, base_dist.event_shape) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 4cb390e6c..7ba82f470 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -263,19 +263,16 @@ def process_message(self, msg): frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size) cond_indep_stack.append(frame) expected_shape = self._get_batch_shape(cond_indep_stack) - dist_batch_shape = msg['fn'].batch_shape if msg['type'] == 'sample' else () - overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0) - trailing_shape = expected_shape[overlap_idx:] - # e.g. distribution with batch shape (1, 5) cannot be broadcast to (5, 5) - broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape) - if broadcast_shape != dist_batch_shape: - raise ValueError('Distribution batch shape = {} cannot be broadcast up to {}. ' - 'Consider using unbatched distributions.' - .format(dist_batch_shape, broadcast_shape)) - batch_shape = expected_shape[:overlap_idx] - if 'sample_shape' in msg['kwargs']: - batch_shape = lax.broadcast_shapes(msg['kwargs']['sample_shape'], batch_shape) - msg['kwargs']['sample_shape'] = batch_shape + if msg['type'] == 'sample': + dist_batch_shape = msg['fn'].batch_shape + if 'sample_shape' in msg['kwargs']: + dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape + msg['kwargs']['sample_shape'] = () + overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0) + trailing_shape = expected_shape[overlap_idx:] + broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape) + batch_shape = expected_shape[:overlap_idx] + broadcast_shape + msg['fn'] = msg['fn'].expand(batch_shape) if self.size != self.subsample_size: scale = 1. if msg['scale'] is None else msg['scale'] msg['scale'] = scale * self.size / self.subsample_size