
Use ExpandedDistribution in numpyro.plate #616

Merged · 4 commits · Jun 8, 2020
1 change: 1 addition & 0 deletions numpyro/distributions/distribution.py
```diff
@@ -311,6 +311,7 @@ class ExpandedDistribution(Distribution):

     def __init__(self, base_dist, batch_shape=()):
         if isinstance(base_dist, ExpandedDistribution):
+            batch_shape = self._broadcast_shape(base_dist.batch_shape, batch_shape)
             base_dist = base_dist.base_dist
```
**Member:**
This optimization is addressed above. Do you think we need extra logic here?

**Member Author:**
I missed that during merge. I think we should still do an explicit broadcast so that `dist.expand(5, 4).expand(6, 4)` raises an error. Will fix.

**Member (@fehiepsi), Jun 8, 2020:**
Hmm, `.expand(5, 4).expand(6, 4)` seems to be invalid, hence an error is expected. Did you mean something like `.expand(5, 1).expand(1, 4)`?

**Member Author:**
You are right!

```diff
         self.base_dist = base_dist
         super().__init__(base_dist.batch_shape, base_dist.event_shape)
```
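The constructor behavior discussed in this thread — collapsing a nested expansion into a single layer while still broadcasting the two batch shapes — can be sketched with a minimal stand-in class. This is a toy `ExpandedDist`, not the real NumPyro class, and `np.broadcast_shapes` plays the role of the internal `_broadcast_shape` helper:

```python
import numpy as np

class ExpandedDist:
    """Toy stand-in for the ExpandedDistribution constructor logic (a sketch)."""
    def __init__(self, base_dist, batch_shape=()):
        if isinstance(base_dist, ExpandedDist):
            # Explicit broadcast, so incompatible chained expands still raise.
            batch_shape = np.broadcast_shapes(base_dist.batch_shape, batch_shape)
            base_dist = base_dist.base_dist  # collapse to a single layer
        self.base_dist = base_dist
        self.batch_shape = tuple(batch_shape)

base = object()  # placeholder for an unexpanded base distribution
d = ExpandedDist(ExpandedDist(base, (5, 1)), (1, 4))
print(d.batch_shape)        # (5, 4): the two expansions broadcast together
print(d.base_dist is base)  # True: only one ExpandedDist layer remains

try:
    ExpandedDist(ExpandedDist(base, (5, 4)), (6, 4))
except ValueError:
    print("incompatible shapes raise")  # (5, 4) vs (6, 4) cannot broadcast
```

Keeping a single layer is what makes sites under multiple `plate` statements (or repeated `.expand` calls) easy to inspect, per the discussion below.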
23 changes: 10 additions & 13 deletions numpyro/primitives.py
```diff
@@ -263,19 +263,16 @@ def process_message(self, msg):
         frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
         cond_indep_stack.append(frame)
         expected_shape = self._get_batch_shape(cond_indep_stack)
-        dist_batch_shape = msg['fn'].batch_shape if msg['type'] == 'sample' else ()
-        overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
-        trailing_shape = expected_shape[overlap_idx:]
-        # e.g. distribution with batch shape (1, 5) cannot be broadcast to (5, 5)
-        broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)
-        if broadcast_shape != dist_batch_shape:
-            raise ValueError('Distribution batch shape = {} cannot be broadcast up to {}. '
-                             'Consider using unbatched distributions.'
-                             .format(dist_batch_shape, broadcast_shape))
-        batch_shape = expected_shape[:overlap_idx]
-        if 'sample_shape' in msg['kwargs']:
-            batch_shape = lax.broadcast_shapes(msg['kwargs']['sample_shape'], batch_shape)
-        msg['kwargs']['sample_shape'] = batch_shape
+        if msg['type'] == 'sample':
+            dist_batch_shape = msg['fn'].batch_shape
+            if 'sample_shape' in msg['kwargs']:
+                dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape
+                msg['kwargs']['sample_shape'] = ()
+            overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
+            trailing_shape = expected_shape[overlap_idx:]
+            broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)
+            batch_shape = expected_shape[:overlap_idx] + broadcast_shape
+            msg['fn'] = msg['fn'].expand(batch_shape)
```
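The shape arithmetic of the new branch can be sketched in plain Python. Here `plate_batch_shape` is a hypothetical helper written for illustration (not a numpyro function), and `np.broadcast_shapes` stands in for `lax.broadcast_shapes`:

```python
import numpy as np

def plate_batch_shape(expected_shape, dist_batch_shape, sample_shape=()):
    """Sketch of the shape the site's fn is expanded to under plates."""
    # sample_shape is folded into the distribution's batch shape up front...
    full_shape = tuple(sample_shape) + tuple(dist_batch_shape)
    # ...then the trailing dims must broadcast against the plate-implied shape,
    # and any extra leading plate dims are prepended as-is.
    overlap_idx = max(len(expected_shape) - len(full_shape), 0)
    trailing_shape = expected_shape[overlap_idx:]
    broadcast_shape = np.broadcast_shapes(trailing_shape, full_shape)
    return tuple(expected_shape[:overlap_idx]) + tuple(broadcast_shape)

# a plate of size 5 at dim -1 over a distribution with batch shape (1,):
print(plate_batch_shape((5,), (1,)))        # (5,)
# nested plates implying (3, 5) over a scalar-batch distribution:
print(plate_batch_shape((3, 5), ()))        # (3, 5)
# sample_shape (2,) is absorbed into the expanded batch shape:
print(plate_batch_shape((5,), (1,), (2,)))  # (2, 5)
```

Note how a batch shape that the old code rejected with a `ValueError` (e.g. `(1, 5)` under plates implying `(5, 5)`) is now simply broadcast up and handled by `.expand`.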
**Member:**
Should we also expand when sample_shape is available at the process or postprocess trace handler, for consistency?

**Member:**
Also, maybe at the constructor of ExpandedDist, we can check if the base dist is an expanded dist. If so, we just need to store `base_dist.base_dist`. This is convenient for inspecting the sites under multiple plate statements.

**Member Author:**
> Should we also expand when sample_shape is available at process or postprocess trace handler for consistency?

sample_shape should already be incorporated. Right now, the behavior is that the distribution's intrinsic shape is `sample_shape + dist.batch_shape`, which needs to broadcast with all the plates. This is what we had earlier, so the behavior is kept consistent. We could instead do sample_shape + the shape broadcast from the plates and batch shape, but that could be surprising. Ideally, we shouldn't use sample_shape with plate, to avoid this kind of confusion.

> Also, maybe at the constructor of ExpandedDist, we can check if base dist is expanded dist. If so, we just need to store base_dist.base_dist.

Let me see how we are doing this in Pyro. I agree multiple layers of ExpandedDist are not great for debugging.

**Member:**
I see. It is a bit confusing. I like your change because it solves the shape-broadcasting issue that we had before (the `ValueError` in the previous code).

**Member Author:**
Regarding:

> Also, maybe at the constructor of ExpandedDist, we can check if base dist is expanded dist. If so, we just need to store base_dist.base_dist.

Just by using nested plates, we will never hit the condition where the base_dist is already an ExpandedDistribution. This can only happen if the user uses `.expand` and `plate` together. Does that make sense? We can still introduce this optimization, but it shouldn't affect the vast majority of models.

**Member:**
I see. I already added this optimization in #617 (just to make sure that we can use TransformReparam with multiple `.expand` ops).

```diff
         if self.size != self.subsample_size:
             scale = 1. if msg['scale'] is None else msg['scale']
             msg['scale'] = scale * self.size / self.subsample_size
```