Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ExpandedDistribution in numpyro.plate #616

Merged
merged 4 commits into from
Jun 8, 2020
Merged

Use ExpandedDistribution in numpyro.plate #616

merged 4 commits into from
Jun 8, 2020

Conversation

neerajprad
Copy link
Member

@neerajprad neerajprad commented Jun 3, 2020

Fixes #615. I noticed this while reviewing #612.

  • Check that there is no perf regression.

@@ -49,7 +49,6 @@ def model():

with handlers.trace() as tr:
value = handlers.seed(model, 0)()
assert isinstance(tr["x"]["fn"], dist.TransformedDistribution)
Copy link
Member Author

@neerajprad neerajprad Jun 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi - The TransformedDistribution will be the base distribution (of ExpandedDistribution) now, so I am not sure if this PR can go through until we resolve #610.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I will address it.

trailing_shape = expected_shape[overlap_idx:]
broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)
batch_shape = expected_shape[:overlap_idx] + broadcast_shape
msg['fn'] = msg['fn'].expand(batch_shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also expand when sample_shape is available at process or postprocess trace handler for consistency?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, maybe at the constructor of ExpandedDist, we can check if base dist is expanded dist. If so, we just need to store base_dist.base_dist. This is convenient for inspecting the sites under multiple plate statements.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also expand when sample_shape is available at process or postprocess trace handler for consistency?

sample_shape should already incorporated. Right now, the behavior is that the distribution's intrinsic shape is sample_shape + dist.batch_shape which needs to broadcast with all the plates. This was what we had earlier, so the behavior is kept consistent. We can instead do sample_shape + broadcasted shape from plates and batch shape, but that can be surprising. Ideally, we shouldn't use sample_shape with plate to avoid this kind of confusion.

Also, maybe at the constructor of ExpandedDist, we can check if base dist is expanded dist. If so, we just need to store base_dist.base_dist.

Let me see how we are doing this in Pyro, I agree multiple layers of ExpandedDist is not great for debugging.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It is a bit confusing. I like your change because it solves the shape broadcasting issue that we have before (the ValueError in the previous code).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding..

Also, maybe at the constructor of ExpandedDist, we can check if base dist is expanded dist. If so, we just need to store base_dist.base_dist.

Just by using nested plates, we will never hit this condition where the base_dist is already an ExpandedDistribution. This can only happen if the user uses .expand and plate together. Does that make sense? We can still introduce this optimization, but it shouldn't affect the vast majority of models.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I already add this optimization in #617 (just to make sure that we can use TransformReparam with multiple .expand op).

fehiepsi
fehiepsi previously approved these changes Jun 4, 2020
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me! Which performance regressions that you think about? I think those shape manipulations here and in expand will negligible under jit.

@fehiepsi
Copy link
Member

fehiepsi commented Jun 7, 2020

You'll need to merge the master branch to pass CI.

@neerajprad
Copy link
Member Author

Which performance regressions that you think about? I think those shape manipulations here and in expand will negligible under jit.

I just wanted to see if we noticed any slowness in our test suite, but there is negligible effect, as you noted. This should be good to go once tests pass.

@@ -313,6 +313,9 @@ def __init__(self, base_dist, batch_shape=()):
if isinstance(base_dist, ExpandedDistribution):
base_dist = base_dist.base_dist
self.base_dist = base_dist
if isinstance(base_dist, ExpandedDistribution):
batch_shape = self._broadcast_shape(base_dist.batch_shape, batch_shape)
base_dist = base_dist.base_dist
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This optimization is addressed above. Do you think we need extra logic here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed that during merge, I think we should still do an explicit broadcast so that. dist.expand(5, 4).expand(6, 4) raises an error. Will fix.

Copy link
Member

@fehiepsi fehiepsi Jun 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, .expand(5, 4).expand(6, 4) seems to be invalid, hence an error is expected. Did you mean something like: .expand(5, 1).expand(1, 4)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right!!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Thanks for resolving several bugs, @neerajprad !

@fehiepsi fehiepsi merged commit f734661 into master Jun 8, 2020
@fehiepsi fehiepsi deleted the expanded-dist branch June 8, 2020 22:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use ExpandedDistribution in numpyro.plate
2 participants