Skip to content

Commit

Permalink
Use ExpandedDistribution in numpyro.plate (#616)
Browse files (browse the repository at this point in the history)
* Use ExpandedDistribution in numpyro.plate

* add .expand optimization

* fix merge
  • Loading branch information
neerajprad authored Jun 8, 2020
1 parent cc6ede0 commit f734661
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
1 change: 1 addition & 0 deletions numpyro/distributions/distribution.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -311,6 +311,7 @@ class ExpandedDistribution(Distribution):

def __init__(self, base_dist, batch_shape=()):
if isinstance(base_dist, ExpandedDistribution):
batch_shape = self._broadcast_shape(base_dist.batch_shape, batch_shape)
base_dist = base_dist.base_dist
self.base_dist = base_dist
super().__init__(base_dist.batch_shape, base_dist.event_shape)
Expand Down
23 changes: 10 additions & 13 deletions numpyro/primitives.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -263,19 +263,16 @@ def process_message(self, msg):
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
expected_shape = self._get_batch_shape(cond_indep_stack)
dist_batch_shape = msg['fn'].batch_shape if msg['type'] == 'sample' else ()
overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
trailing_shape = expected_shape[overlap_idx:]
# e.g. distribution with batch shape (1, 5) cannot be broadcast to (5, 5)
broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)
if broadcast_shape != dist_batch_shape:
raise ValueError('Distribution batch shape = {} cannot be broadcast up to {}. '
'Consider using unbatched distributions.'
.format(dist_batch_shape, broadcast_shape))
batch_shape = expected_shape[:overlap_idx]
if 'sample_shape' in msg['kwargs']:
batch_shape = lax.broadcast_shapes(msg['kwargs']['sample_shape'], batch_shape)
msg['kwargs']['sample_shape'] = batch_shape
if msg['type'] == 'sample':
dist_batch_shape = msg['fn'].batch_shape
if 'sample_shape' in msg['kwargs']:
dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape
msg['kwargs']['sample_shape'] = ()
overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
trailing_shape = expected_shape[overlap_idx:]
broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)
batch_shape = expected_shape[:overlap_idx] + broadcast_shape
msg['fn'] = msg['fn'].expand(batch_shape)
if self.size != self.subsample_size:
scale = 1. if msg['scale'] is None else msg['scale']
msg['scale'] = scale * self.size / self.subsample_size
Expand Down

0 comments on commit f734661

Please sign in to comment.