Add MixtureSameFamily distribution #4180
Conversation
Co-authored-by: lucianopaz <luciano.paz.neuro@gmail.com>
Great @Sayam753! I'll review tomorrow.
Codecov Report
@@ Coverage Diff @@
## master #4180 +/- ##
==========================================
+ Coverage 88.87% 88.90% +0.03%
==========================================
Files 89 89
Lines 14343 14422 +79
==========================================
+ Hits 12747 12822 +75
- Misses 1596 1600 +4
Thanks @Sayam753, this looks really nice!
I'll let Luciano approve / request changes as he's the specialist here, but I just left a few suggestions and questions below. Tell me if anything's unclear 😉
pymc3/distributions/mixture.py (outdated):
    class MixtureSameFamily(Distribution):
        R"""
        Mixture log-likelihood
Could we quickly explain when this distribution should be used, compared to the existing Mixture?
Totally agree with @AlexAndorra here. We should write in the docstring that this distribution is needed to be able to handle mixtures of multivariate distributions in a vectorized manner.
pymc3/distributions/mixture.py (outdated):
    w >= 0 and w <= 1
        the mixture weights
    comp_dists: multidimensional PyMC3 distribution (e.g. `pm.Multinomial.dist(...)`)
        with shape (batch_shape, mixture_axis, event_shape)
Could we quickly explain what batch_shape and event_shape mean? It'll surely be useful to new users / users not familiar with TFP terminology.
Yes, pymc3 makes no distinction between batch_shape and event_shape. Moreover, pymc3 doesn't consistently reduce the event_shape dimensions in the logp calculations because it just completely reduces all axes in the model's logp. This prevents pymc3 from being able to vectorize its step functions across chains, as TFP does.
For someone who doesn't know TFP or didn't have anything to do with pymc4, the batch and event shape will be completely foreign concepts. I don't know if we should explain what they are; we should simply emphasize that the shape of the component distribution passed to the MixtureSameFamily has an axis that the user can identify as the mixture_axis. We could write it down like (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N). And we should also say that the mixture_axis will be "consumed" by the mixture distribution, so the user should be aware that its mixture will end up with a shape like (i_0, ..., i_n, i_n+1, ..., i_N).
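For illustration, here is a small numpy sketch (hypothetical shapes, not pymc3's actual theano implementation) of how the mixture axis gets marginalised away in the logp:

    import numpy as np
    from scipy.special import logsumexp

    # Hypothetical shapes: batch (i_0, ..., i_n) = (5,), 3 mixture components,
    # event dimensions already reduced inside each component's logp.
    comp_logp = np.random.randn(5, 3)   # per-component log-probabilities
    w = np.full(3, 1.0 / 3.0)           # mixture weights

    mixture_axis = -1
    logp = logsumexp(np.log(w) + comp_logp, axis=mixture_axis)
    print(logp.shape)  # (5,) -- the mixture_axis has been "consumed"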
        Axis to be reduced in the mixture
        """

    def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
Does mixture_axis=-1 mean the default is to reduce along the last axis? If yes, it'd be nice to add this precision to the docstring just above.
Yes, we should say this in the docstring.
pymc3/distributions/mixture.py (outdated):
        self.mixture_axis = mixture_axis
        kwargs.setdefault("dtype", self.comp_dists.dtype)

        # Computvalte the mode so we don't always have to pass a tes
I'm guessing this sentence is:
    - # Computvalte the mode so we don't always have to pass a tes
    + # Compute the mode so we don't always have to pass a test:
It actually was supposed to say `# Compute the mode so we don't always have to pass a testval`, but the mouse cursor seems to have betrayed us.
pymc3/distributions/mixture.py (outdated):
        # self.w.shape (batch_shape, mixture_axis)
        # self.comp_dists.shape (batch_shape, mixture_axis, event_shape)
        # value.shape (batch_shape, event_shape)
Should these lines be deleted?
Maybe, yes. They were just there to help debug shape problems.
pymc3/distributions/mixture.py (outdated):
        # Second, we have to add the mixture_axis to the value tensor
        # To insert the mixture axis at the correct location, we use the
        # negative number index. This way, we can also handle situations
        # in which, value is an observed value with more batch dimensions
        # than the ones present in the comp_dists
Suggested change:
    - # Second, we have to add the mixture_axis to the value tensor
    - # To insert the mixture axis at the correct location, we use the
    - # negative number index. This way, we can also handle situations
    - # in which, value is an observed value with more batch dimensions
    - # than the ones present in the comp_dists
    + # Second, we have to add the `mixture_axis` to the `value` tensor.
    + # To insert the mixture axis at the correct location, we use the
    + # negative number index. This way, we can also handle situations
    + # in which `value` is an observed value with more batch dimensions
    + # than the ones present in the `comp_dists`.
I don't think that we should add markdown monospace delimiters in the comments. They won't be rendered anywhere; they're just there to help other developers understand what's going on.
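To make the negative-index trick concrete, here is a hedged numpy sketch (hypothetical shapes, not the PR's theano code) of inserting the mixture axis into a value tensor that carries extra leading batch dimensions:

    import numpy as np

    # comp_dists has ndim 3 with shape (batch, mixture, event) = (5, 3, 4);
    # the observed value has one extra leading batch dimension: (7, 5, 4).
    value = np.random.randn(7, 5, 4)
    comp_dists_ndim = 3
    mixture_axis = 1  # positive index into comp_dists' dimensions

    # Counting from the right keeps the mixture axis in the correct place even
    # when value has more batch dimensions than comp_dists.
    value_ = np.expand_dims(value, axis=mixture_axis - comp_dists_ndim)
    print(value_.shape)  # (7, 5, 1, 4), which broadcasts against (5, 3, 4)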
pymc3/distributions/mixture.py (outdated):
        # First we draw values for the mixture component weights
        (w,) = draw_values([self.w], point=point, size=size)

        # We now draw do random choices from those weights
Suggested change:
    - # We now draw do random choices from those weights
    + # We now draw random choices from those weights.
pymc3/distributions/mixture.py (outdated):

        # We now draw do random choices from those weights
        # However, we have to ensure that the number of choices has the
        # sample_shape prepent
Is "prepent" a word?
Oops. It was supposed to be prepend.
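As an illustration of that shape requirement, here is a hedged numpy sketch (hypothetical shapes; the PR draws the choices differently) in which the choices end up with the sample_shape prepended to the batch shape:

    import numpy as np

    w = np.random.dirichlet(np.ones(3), size=5)  # weights, shape (batch, mixture) = (5, 3)
    sample_shape = (10,)

    # Inverse-CDF sampling of one component index per (sample, batch) entry,
    # so choices has shape sample_shape + batch_shape = (10, 5).
    cum_w = np.cumsum(w, axis=-1)                     # (5, 3)
    u = np.random.rand(*sample_shape, w.shape[0], 1)  # (10, 5, 1)
    choices = np.argmax(u <= cum_w, axis=-1)          # (10, 5)
    print(choices.shape)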
pymc3/distributions/mixture.py (outdated):
        # To be able to take the choices along the mixture_axis of the
        # comp_samples, we have to add in dimensions to the right the
        # choices array
        # We also need to make sure that the batch_shapes of both the comp_samples
        # and choices broadcast with each other
Suggested change:
    - # To be able to take the choices along the mixture_axis of the
    - # comp_samples, we have to add in dimensions to the right the
    - # choices array
    - # We also need to make sure that the batch_shapes of both the comp_samples
    - # and choices broadcast with each other
    + # To be able to take the choices along the `mixture_axis` of the
    + # `comp_samples`, we have to add in dimensions to the right of the
    + # `choices` array.
    + # We also need to make sure that the batch_shapes of both the `comp_samples`
    + # and `choices` broadcast with each other.
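Here is a hedged numpy sketch (hypothetical shapes, not the PR's code) of the indexing described above, including the removal of the consumed mixture axis:

    import numpy as np

    # comp_samples: (sample, batch, mixture, event) = (10, 5, 3, 4)
    # choices:      (sample, batch)                 = (10, 5)
    comp_samples = np.random.randn(10, 5, 3, 4)
    choices = np.random.randint(0, 3, size=(10, 5))

    mixture_axis = -2
    # Add dimensions to the right of choices so its batch shape broadcasts
    # with comp_samples and it can index along the mixture axis.
    choices_ = choices[..., np.newaxis, np.newaxis]               # (10, 5, 1, 1)
    samples = np.take_along_axis(comp_samples, choices_, axis=mixture_axis)
    samples = np.squeeze(samples, axis=mixture_axis)              # drop the consumed axis
    print(samples.shape)  # (10, 5, 4)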
pymc3/distributions/mixture.py (outdated):
            comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
        )

        # The samples array still has the mixture_axis, so we must remove it
Suggested change:
    - # The samples array still has the mixture_axis, so we must remove it
    + # The `samples` array still has the `mixture_axis`, so we must remove it:
It looks very good @Sayam753. The comments that Alex left are super helpful and you should add the things he suggested before merging this PR. Are you also trying to get a test using the MvNormal?
        )
        self.comp_dists = comp_dists
        if mixture_axis < 0:
            mixture_axis = len(comp_dists.shape) + mixture_axis
We need to ensure that after we do this, the mixture_axis is positive and falls within the number of dimensions of comp_dists.
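A minimal sketch of the kind of check being suggested (a hypothetical helper, not the PR's actual code):

    def normalize_mixture_axis(mixture_axis, ndim):
        """Convert a possibly negative axis to a positive one and validate it."""
        if mixture_axis < 0:
            mixture_axis += ndim
        if not 0 <= mixture_axis < ndim:
            raise ValueError(
                f"mixture_axis is out of bounds for a comp_dists with {ndim} dimensions"
            )
        return mixture_axis

    print(normalize_mixture_axis(-1, 3))  # 2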
            tt.argmax(_w, keepdims=True),
            axis=mixture_axis,
        )
        self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
I was kind of lazy and didn't add the self.mean. It's only necessary if the components are continuous distributions. If you want, you can try to copy the implementation from the regular Mixture here.
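For reference, a hedged numpy sketch (hypothetical shapes; the regular Mixture does this with theano tensors) of what such a mean would compute:

    import numpy as np

    # Component means: (batch, mixture, event) = (5, 3, 4); weights: (batch, mixture) = (5, 3)
    comp_means = np.random.randn(5, 3, 4)
    w = np.full((5, 3), 1.0 / 3.0)
    mixture_axis = -2

    # The mixture mean is the weighted average of the component means,
    # reduced over the mixture axis.
    mean = np.sum(np.expand_dims(w, -1) * comp_means, axis=mixture_axis)
    print(mean.shape)  # (5, 4)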
A general comment:
Yes @junpenglao. I am also in favour of adding that. @AlexAndorra @lucianopaz, your thoughts on this?
I agree, but let's leave that for a different PR.
My personal opinion is that the current
@lucianopaz I like your approach. Could you add a bug with the steps so we are clear on the road map?
Sure, @junpenglao. I'll write something during the weekend.
Focusing on this PR, is there anything missing before we can merge?
Yes. I am writing tests with the MvNormal distribution. Will push here soon.
@twiecki, nothing really big is missing. We should address Alex's comments on the docstrings before merging. Then also add a test with MvNormal.
Hi @lucianopaz. At present, this is the behaviour of sample_prior_predictive:

    with pm.Model() as model:
        mu = pm.Gamma("mu", 1.0, 1.0, shape=2)
        comp_dists = pm.Poisson.dist(mu)
        mix = pm.MixtureSameFamily("mix", w=np.ones(2)/2, comp_dists=comp_dists, shape=(1000,))
        prior = pm.sample_prior_predictive(samples=10)

    prior['mix'].shape
    # outputs (10, )
What is the
Considering how
I would imagine a similar handling here in MixtureSameFamily for the same reason.
Good catch @Sayam753! From what I can tell, the logp seems to be fine in that it always adds the mixture axis to the value tensor. However, we do have to check if the value broadcasts with
It's time to watch @lucianopaz's PyMCon talk to learn more about posterior predictive sampling. 🥳
MvNormal is very tricky to understand.
Is this a broadcasting issue with MvNormal (when sample_shape equals batch_shape), or is this expected behaviour?
There is a bug there, I believe...
So, should there be a new issue opened regarding this?
@Sayam753 Yes.
Handled broadcasting in case observed data has more batch dimensions
Written tests for MvNormal
I have addressed all the suggestions, and broadcasting with respect to observed data is also handled in the latest commit. So, how about another round of review?
Thanks @Sayam753. I'll review later today.
Added a mention in RELEASE-NOTES.md
This looks really good to me now, thanks a lot @Sayam753!
I'll wait for Luciano's approval before merging
Thanks @Sayam753! I'll approve and merge because it looks like everything is in order now. I'm not sure if every condition is completely tested, but we'll find out as people are able to use this distribution for themselves.
* Added mixture same distribution and its tests
  Co-authored-by: lucianopaz <luciano.paz.neuro@gmail.com>
* Fixed pyupgrade error
* Fixed suggestions
* Written tests for broadcasting
  Handled broadcasting in case observed data has more batch dimensions
  Written tests for MvNormal
* Added MixtureSameFamily name in rst files
  Added a mention in RELEASE-NOTES.md

Co-authored-by: lucianopaz <luciano.paz.neuro@gmail.com>
Co-authored-by: lucianopaz luciano.paz.neuro@gmail.com
Thank you for opening a PR!
Depending on what your PR does, here are a few things you might want to address in the description:
what are the (breaking) changes that this PR makes?
PyMC3's Mixture distribution considers the mixture components over the last dimension. This makes it difficult to handle mixtures over multivariate distributions, because the last dimension is the event_shape and thus cannot represent mixture components. This PR adds a new MixtureSameFamily distribution, making it easy to specify mixtures over multivariate distributions along a specified mixture_axis. All calculations for the logp and random methods are vectorized, thereby making them much faster.
important background, or details about the implementation
The multivariate distribution is passed as an argument and has shape (batch_shape, mixture_axis, event_shape). For computing logp and drawing random samples, the mixture_axis is reduced / marginalised away. This is similar to the MixtureSameFamily distribution in TFP.
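To illustrate, here is a sketch of how this could be used with Multinomial components. The exact shape and parameter choices here are my own assumptions for illustration, not taken from the PR's tests:

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        # 3 mixture components, each a Multinomial over a 4-dimensional event
        w = pm.Dirichlet("w", a=np.ones(3))
        p = pm.Dirichlet("p", a=np.ones((3, 4)), shape=(3, 4))
        comp_dists = pm.Multinomial.dist(n=10, p=p, shape=(3, 4))

        # comp_dists has shape (mixture_axis, event_shape) = (3, 4); the mixture
        # consumes axis -2, leaving a distribution with event shape (4,).
        mix = pm.MixtureSameFamily(
            "mix", w=w, comp_dists=comp_dists, mixture_axis=-2, shape=(4,)
        )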
are the changes—especially new features—covered by tests and docstrings?
Yes. I have added a few tests with a mixture of Multinomial distributions. I have been trying to add a mixture over the MvNormal distribution as a test, but I could not do so because it is difficult to form a vectorized MvNormal distribution with a given batch_shape. I need help writing them.
right before it's ready to merge, mention the PR in the RELEASE-NOTES.md
Once tests pass, I will add a mention in RELEASE-NOTES.md.