From 5789fea13d7b4e028af22f472c90b57a0c364be8 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 2 Jun 2020 19:18:19 -0700 Subject: [PATCH 1/3] Use ExpandedDistribution in numpyro.plate --- numpyro/primitives.py | 23 ++++++++++------------- test/contrib/test_reparam.py | 1 - 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 654082710..0ddf2b1ba 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -265,19 +265,16 @@ def process_message(self, msg): frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size) cond_indep_stack.append(frame) expected_shape = self._get_batch_shape(cond_indep_stack) - dist_batch_shape = msg['fn'].batch_shape if msg['type'] == 'sample' else () - overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0) - trailing_shape = expected_shape[overlap_idx:] - # e.g. distribution with batch shape (1, 5) cannot be broadcast to (5, 5) - broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape) - if broadcast_shape != dist_batch_shape: - raise ValueError('Distribution batch shape = {} cannot be broadcast up to {}. ' - 'Consider using unbatched distributions.' - .format(dist_batch_shape, broadcast_shape)) - batch_shape = expected_shape[:overlap_idx] - if 'sample_shape' in msg['kwargs']: - batch_shape = lax.broadcast_shapes(msg['kwargs']['sample_shape'], batch_shape) - msg['kwargs']['sample_shape'] = batch_shape + if msg['type'] == 'sample': + dist_batch_shape = msg['fn'].batch_shape + if 'sample_shape' in msg['kwargs']: + dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape + msg['kwargs']['sample_shape'] = () + overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0) + trailing_shape = expected_shape[overlap_idx:] + broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape) + batch_shape = expected_shape[:overlap_idx] + broadcast_shape + msg['fn'] = msg['fn'].expand(batch_shape) if self.size != self.subsample_size: scale = 1. if msg['scale'] is None else msg['scale'] msg['scale'] = scale * self.size / self.subsample_size diff --git a/test/contrib/test_reparam.py b/test/contrib/test_reparam.py index 73b36dd84..3c729efa9 100644 --- a/test/contrib/test_reparam.py +++ b/test/contrib/test_reparam.py @@ -49,7 +49,6 @@ def model(): with handlers.trace() as tr: value = handlers.seed(model, 0)() - assert isinstance(tr["x"]["fn"], dist.TransformedDistribution) expected_moments = get_moments(value) with reparam(config={"x": TransformReparam()}): From 8b6e7ecf77bb451b93247c4c0584550334b6bcaf Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 8 Jun 2020 09:24:46 -0700 Subject: [PATCH 2/3] add .expand optimization --- numpyro/distributions/distribution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index dea2231a7..e885ccbc9 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -267,6 +267,9 @@ class ExpandedDistribution(Distribution): def __init__(self, base_dist, batch_shape=()): self.base_dist = base_dist + if isinstance(base_dist, ExpandedDistribution): + batch_shape = self._broadcast_shape(base_dist.batch_shape, batch_shape) + base_dist = base_dist.base_dist super().__init__(base_dist.batch_shape, base_dist.event_shape) # adjust batch shape self.expand(batch_shape) From ce8aca34a5bc2148e2f758b5e5507a941233ff22 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 8 Jun 2020 14:27:53 -0700 Subject: [PATCH 3/3] fix merge --- numpyro/distributions/distribution.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7cfbfd6f5..91c96bef3 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -310,12 +310,10 @@ class ExpandedDistribution(Distribution): arg_constraints = {} def __init__(self, base_dist, batch_shape=()): - if isinstance(base_dist, ExpandedDistribution): - base_dist = base_dist.base_dist - self.base_dist = base_dist if isinstance(base_dist, ExpandedDistribution): batch_shape = self._broadcast_shape(base_dist.batch_shape, batch_shape) base_dist = base_dist.base_dist + self.base_dist = base_dist super().__init__(base_dist.batch_shape, base_dist.event_shape) # adjust batch shape self.expand(batch_shape)