Use ExpandedDistribution in numpyro.plate #616
Conversation
test/contrib/test_reparam.py
Outdated
@@ -49,7 +49,6 @@ def model():

    with handlers.trace() as tr:
        value = handlers.seed(model, 0)()
    assert isinstance(tr["x"]["fn"], dist.TransformedDistribution)
Yeah, I will address it.
trailing_shape = expected_shape[overlap_idx:]
broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)
batch_shape = expected_shape[:overlap_idx] + broadcast_shape
msg['fn'] = msg['fn'].expand(batch_shape)
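A standalone sketch of what this snippet computes, using numpy.broadcast_shapes in place of lax.broadcast_shapes. Note that overlap_idx is derived here as an assumption (the leading plate dims are whatever extends beyond the distribution's own batch shape); in the handler it comes from the surrounding code.

```python
import numpy as np

def expanded_batch_shape(expected_shape, dist_batch_shape):
    # Assumed definition of overlap_idx: only the trailing dims of
    # expected_shape overlap with the distribution's own batch shape.
    overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
    # Broadcast the overlapping trailing dims and keep the leading
    # plate dims unchanged.
    trailing_shape = expected_shape[overlap_idx:]
    broadcast_shape = np.broadcast_shapes(trailing_shape, dist_batch_shape)
    return expected_shape[:overlap_idx] + broadcast_shape

print(expanded_batch_shape((3, 1, 2), (5, 2)))  # (3, 5, 2)
```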
Should we also expand when sample_shape is available at process or postprocess trace handler for consistency?
Also, maybe at the constructor of ExpandedDist, we can check if the base dist is an expanded dist. If so, we just need to store base_dist.base_dist. This is convenient for inspecting the sites under multiple plate statements.
Should we also expand when sample_shape is available at process or postprocess trace handler for consistency?
sample_shape should already be incorporated. Right now, the behavior is that the distribution's intrinsic shape is sample_shape + dist.batch_shape, which needs to broadcast with all the plates. This was what we had earlier, so the behavior is kept consistent. We could instead do sample_shape + the shape broadcast from the plates and batch shape, but that can be surprising. Ideally, we shouldn't use sample_shape with plate, to avoid this kind of confusion.
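As a toy illustration of the shape rule described above (plain numpy rather than numpyro, so the shapes are easy to check): a draw uses sample_shape followed by the batch shape, and it is this combined shape that must broadcast against the enclosing plates.

```python
import numpy as np

# Stand-in for a distribution with batch_shape (3,): drawing with
# sample_shape (2,) yields shape sample_shape + batch_shape = (2, 3).
loc = np.zeros(3)                  # batch_shape (3,)
sample_shape = (2,)
draw = np.random.standard_normal(sample_shape + loc.shape) + loc
print(draw.shape)  # (2, 3)
```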
Also, maybe at the constructor of ExpandedDist, we can check if base dist is expanded dist. If so, we just need to store base_dist.base_dist.
Let me see how we are doing this in Pyro. I agree multiple layers of ExpandedDist are not great for debugging.
I see. It is a bit confusing. I like your change because it solves the shape broadcasting issue that we had before (the ValueError in the previous code).
Regarding:

Also, maybe at the constructor of ExpandedDist, we can check if base dist is expanded dist. If so, we just need to store base_dist.base_dist.

Just by using nested plates, we will never hit the condition where the base_dist is already an ExpandedDistribution. This can only happen if the user uses .expand and plate together. Does that make sense? We can still introduce this optimization, but it shouldn't affect the vast majority of models.
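A minimal, hypothetical sketch of the collapsing check being discussed (the ExpandedDistribution class here is a toy stand-in, not numpyro's implementation, and numpy.broadcast_shapes plays the role of the internal broadcast helper):

```python
import numpy as np

class ExpandedDistribution:
    """Toy stand-in illustrating the collapse of nested expansions."""

    def __init__(self, base_dist, batch_shape=()):
        if isinstance(base_dist, ExpandedDistribution):
            # Collapse: broadcast the two batch shapes and keep only
            # the innermost (non-expanded) base distribution.
            batch_shape = np.broadcast_shapes(base_dist.batch_shape, batch_shape)
            base_dist = base_dist.base_dist
        self.base_dist = base_dist
        self.batch_shape = tuple(batch_shape)

    def expand(self, batch_shape):
        return ExpandedDistribution(self, batch_shape)

base = object()  # placeholder for a concrete distribution
d = ExpandedDistribution(base, (5, 1)).expand((1, 4))
print(d.batch_shape)        # (5, 4)
print(d.base_dist is base)  # True: only one layer remains
```

This keeps sites readable under multiple plate/.expand layers, since inspection never has to peel more than one ExpandedDistribution.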
I see. I already added this optimization in #617 (just to make sure that we can use TransformReparam with multiple .expand ops).
Looks great to me! Which performance regressions do you have in mind? I think those shape manipulations, here and in expand, will be negligible under jit.
You'll need to merge the master branch to pass CI.
I just wanted to see if we noticed any slowness in our test suite, but there is negligible effect, as you noted. This should be good to go once tests pass.
@@ -313,6 +313,9 @@ def __init__(self, base_dist, batch_shape=()):
    if isinstance(base_dist, ExpandedDistribution):
        base_dist = base_dist.base_dist
    self.base_dist = base_dist
    if isinstance(base_dist, ExpandedDistribution):
        batch_shape = self._broadcast_shape(base_dist.batch_shape, batch_shape)
        base_dist = base_dist.base_dist
This optimization is addressed above. Do you think we need extra logic here?
I missed that during merge. I think we should still do an explicit broadcast so that dist.expand(5, 4).expand(6, 4) raises an error. Will fix.
Hmm, .expand(5, 4).expand(6, 4) seems to be invalid, hence an error is expected. Did you mean something like .expand(5, 1).expand(1, 4)?
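The distinction between the two cases follows numpy-style broadcasting rules (which lax.broadcast_shapes also follows): (5, 1) and (1, 4) are compatible, while (5, 4) and (6, 4) conflict in the leading dimension, so an explicit broadcast would surface the error.

```python
import numpy as np

# Compatible: chaining the two expands can legitimately yield (5, 4).
print(np.broadcast_shapes((5, 1), (1, 4)))  # (5, 4)

# Incompatible: 5 vs 6 in the leading dim, so an error is expected.
try:
    np.broadcast_shapes((5, 4), (6, 4))
except ValueError as err:
    print("incompatible:", err)
```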
You are right!!
Looks great! Thanks for resolving several bugs, @neerajprad!
Fixes #615. I noticed this while reviewing #612.