
Add MixtureSameFamily distribution #4180

Merged: 5 commits merged into pymc-devs:master on Oct 26, 2020

Conversation

Sayam753
Member

Co-authored-by: lucianopaz luciano.paz.neuro@gmail.com

Thank you for opening a PR!

Depending on what your PR does, here are a few things you might want to address in the description:

  • what are the (breaking) changes that this PR makes?
    PyMC3's Mixture distribution places the mixture components along the last dimension. This makes it difficult to handle mixtures of multivariate distributions, because the last dimension is the event_shape and therefore cannot represent the mixture components. This PR adds a new MixtureSameFamily distribution, which makes it easy to specify mixtures of multivariate distributions along a given mixture_axis. All calculations in the logp and random methods are vectorized, making them much faster.

  • important background, or details about the implementation
    A multivariate distribution is passed as an argument and has shape (batch_shape, mixture_axis, event_shape). When computing logp and drawing random samples, the mixture_axis is reduced/marginalised away. This mirrors the MixtureSameFamily distribution in TFP. A minimal usage sketch is given after this list.

  • are the changes—especially new features—covered by tests and docstrings?
    Yes. I have added a few tests with a mixture of Multinomial distributions. I have also been trying to add tests for a mixture over the MvNormal distribution, but could not, because it is difficult to form a vectorized MvNormal distribution with a given batch_shape. I need help writing them.

  • right before it's ready to merge, mention the PR in the RELEASE-NOTES.md
    Once the tests pass, I will add a mention in RELEASE-NOTES.md.
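
Below is a minimal usage sketch of the new distribution with a Multinomial mixture; the parameter values and shapes are illustrative only and are not taken from the PR's test suite:

import numpy as np
import pymc3 as pm

with pm.Model() as model:
    # Three Multinomial components (the mixture axis), each over 5 categories (the event shape)
    p = np.array([[0.6, 0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.6, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.6, 0.1, 0.1]])
    comp_dists = pm.Multinomial.dist(n=10, p=p, shape=(3, 5))
    w = np.ones(3) / 3
    # comp_dists has shape (3, 5) = (mixture_axis, event_shape); the mixture axis
    # is consumed, so the resulting variable has shape (5,)
    mix = pm.MixtureSameFamily("mix", w=w, comp_dists=comp_dists,
                               mixture_axis=-2, shape=(5,))
    prior = pm.sample_prior_predictive(samples=10)

# prior["mix"].shape is expected to be (10, 5)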

Sayam753 and others added 2 commits October 22, 2020 17:46
Co-authored-by: lucianopaz <luciano.paz.neuro@gmail.com>
@twiecki twiecki requested a review from lucianopaz October 22, 2020 13:29
@lucianopaz
Contributor

Great @Sayam753! I'll review tomorrow

@codecov

codecov bot commented Oct 22, 2020

Codecov Report

Merging #4180 into master will increase coverage by 0.03%.
The diff coverage is 92.20%.


@@            Coverage Diff             @@
##           master    #4180      +/-   ##
==========================================
+ Coverage   88.87%   88.90%   +0.03%     
==========================================
  Files          89       89              
  Lines       14343    14422      +79     
==========================================
+ Hits        12747    12822      +75     
- Misses       1596     1600       +4     
Impacted Files Coverage Δ
pymc3/distributions/mixture.py 89.34% <92.10%> (+0.86%) ⬆️
pymc3/distributions/__init__.py 100.00% <100.00%> (ø)
pymc3/distributions/multivariate.py 81.20% <0.00%> (+0.10%) ⬆️
pymc3/distributions/continuous.py 93.16% <0.00%> (+0.23%) ⬆️

Contributor

@AlexAndorra left a comment

Thanks @Sayam753, this looks really nice!
I'll let Luciano approve / request changes as he's the specialist here, but I just left a few suggestions and questions below. Tell me if anything's unclear 😉


class MixtureSameFamily(Distribution):
R"""
Mixture log-likelihood
Contributor

Could we quickly explain when this distribution should be used, compared to the existing Mixture?

Contributor

Totally agree with @AlexAndorra here. We should write in the docstring that this distribution is needed to be able to handle mixtures of multivariate distributions in a vectorized manner.

w >= 0 and w <= 1
the mixture weights
comp_dists: multidimensional PyMC3 distribution (e.g. `pm.Multinomial.dist(...)`)
with shape (batch_shape, mixture_axis, event_shape)
Contributor

Could we quickly explain what batch_shape and event_shape mean? It would surely be useful to new users and to users not familiar with TFP terminology.

Contributor

Yes, pymc3 makes no distinction between batch_shape and event_shape. Moreover, pymc3 doesn't consistently reduce the event_shape dimensions in the logp calculations because it just completely reduces all axes in the model's logp. This prevents pymc3 from being able to vectorize its step functions across chains, as TFP does.
For someone who doesn't know TFP or didn't have anything to do with pymc4, the batch and event shapes will be completely foreign concepts. I don't know if we should explain what they are; we should simply emphasize that the shape of the component distribution passed to the MixtureSameFamily has an axis that the user can identify as the mixture_axis. We could write it down like (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N). And we should also say that the mixture_axis will be "consumed" by the mixture distribution, so the user should be aware that their mixture will end up with a shape like (i_0, ..., i_n, i_n+1, ..., i_N).
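
A tiny numpy sketch of that shape bookkeeping (the numbers are illustrative, not from the PR):

import numpy as np

comp_samples = np.zeros((7, 3, 5))   # (i_0, mixture_axis, i_1): 3 components on axis 1
w = np.zeros((7, 3))                 # the weights line up with the mixture axis
mix_shape = comp_samples.shape[:1] + comp_samples.shape[2:]
print(mix_shape)                     # (7, 5): the mixture axis has been consumed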

Axis to be reduced in the mixture
"""

def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
Contributor

Does mixture_axis=-1 mean the default is to reduce along the last axis? If yes, it'd be nice to add this detail to the docstring just above.

Contributor

Yes, we should say this in the docstring.

self.mixture_axis = mixture_axis
kwargs.setdefault("dtype", self.comp_dists.dtype)

# Computvalte the mode so we don't always have to pass a tes
Contributor

I'm guessing this sentence is:

Suggested change
# Computvalte the mode so we don't always have to pass a tes
# Compute the mode so we don't always have to pass a test:

Contributor

It actually was supposed to say # Compute the mode so we don't always have to pass a testval but the mouse cursor seems to have betrayed us.

Comment on lines 691 to 693
# self.w.shape (batch_shape, mixture_axis)
# self.comp_dists.shape (batch_shape, mixture_axis, event_shape)
# value.shape (batch_shape, event_shape)
Contributor

Should these lines be deleted?

Contributor

@lucianopaz Oct 22, 2020

Maybe yes. They were just there to help debug shape problems.

Comment on lines 707 to 711
# Second, we have to add the mixture_axis to the value tensor
# To insert the mixture axis at the correct location, we use the
# negative number index. This way, we can also handle situations
# in which, value is an observed value with more batch dimensions
# than the ones present in the comp_dists
Contributor

Suggested change
# Second, we have to add the mixture_axis to the value tensor
# To insert the mixture axis at the correct location, we use the
# negative number index. This way, we can also handle situations
# in which, value is an observed value with more batch dimensions
# than the ones present in the comp_dists
# Second, we have to add the `mixture_axis` to the `value` tensor.
# To insert the mixture axis at the correct location, we use the
# negative number index. This way, we can also handle situations
# in which `value` is an observed value with more batch dimensions
# than the ones present in the `comp_dists`.

Contributor

I don't think that we should add markdown monospace delimiters in the comments. They won't be rendered anywhere. They're just there to help other developers understand what's going on.

# First we draw values for the mixture component weights
(w,) = draw_values([self.w], point=point, size=size)

# We now draw do random choices from those weights
Contributor

Suggested change
# We now draw do random choices from those weights
# We now draw random choices from those weights.


# We now draw do random choices from those weights
# However, we have to ensure that the number of choices has the
# sample_shape prepent
Contributor

Is "prepent" a word?

Contributor

Oops. It was supposed to be prepend.

Comment on lines 770 to 774
# To be able to take the choices along the mixture_axis of the
# comp_samples, we have to add in dimensions to the right the
# choices array
# We also need to make sure that the batch_shapes of both the comp_samples
# and choices broadcast with each other
Contributor

Suggested change
# To be able to take the choices along the mixture_axis of the
# comp_samples, we have to add in dimensions to the right the
# choices array
# We also need to make sure that the batch_shapes of both the comp_samples
# and choices broadcast with each other
# To be able to take the choices along the `mixture_axis` of the
# `comp_samples`, we have to add in dimensions to the right of the
# `choices` array.
# We also need to make sure that the batch_shapes of both the `comp_samples`
# and `choices` broadcast with each other.

comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
)

# The samples array still has the mixture_axis, so we must remove it
Contributor

Suggested change
# The samples array still has the mixture_axis, so we must remove it
# The `samples` array still has the `mixture_axis`, so we must remove it:

Contributor

@lucianopaz left a comment

It looks very good @Sayam753. The comments that Alex left are super helpful and you should add the things he suggested before merging this PR. Are you also trying to get a test using the MvNormal?

)
self.comp_dists = comp_dists
if mixture_axis < 0:
mixture_axis = len(comp_dists.shape) + mixture_axis
Contributor

We need to ensure that, after we do this, the mixture_axis is positive and that it lies within the number of dimensions of comp_dists.
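
A sketch of the kind of guard being suggested, reusing the names from the snippet above (the error message and exact placement are assumptions, not the PR's code):

ndim = len(comp_dists.shape)
if mixture_axis < 0:
    mixture_axis += ndim
if not 0 <= mixture_axis < ndim:
    raise ValueError(
        "mixture_axis must be a valid axis of comp_dists; got axis %d "
        "for a distribution with %d dimensions" % (mixture_axis, ndim)
    )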

tt.argmax(_w, keepdims=True),
axis=mixture_axis,
)
self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
Contributor

I was kind of lazy and didn't add the self.mean. It's only necessary if the components are continuous distributions. If you want, you can try to copy the implementation from the regular Mixture here.
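
For reference, a rough sketch of what such a self.mean could look like: a weighted sum of the component means over the mixture axis (this assumes the component distribution exposes a .mean attribute, and it is not the PR's code):

comp_means = self.comp_dists.mean                    # (..., mixture_axis, event_shape)
w = tt.shape_padright(self.w, len(event_shape))      # let w broadcast against the event dims
self.mean = (w * comp_means).sum(axis=mixture_axis)  # the mixture axis is marginalised away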

@junpenglao
Member

A general comment:
Would it make sense to add the mixture_axis to pm.Mixture instead? pm.Mixture already supports distributions from the same family, so this PR is really about more precise control over how the reduce_sum is done on the components. I suggest exploring adding mixture_axis=None to the kwargs of pm.Mixture as an alternative.

@Sayam753
Member Author

Yes @junpenglao. I am also in favour of adding mixture_axis to the Mixture distribution.
It would need careful broadcasting while computing logp, and the same while drawing samples.

@AlexAndorra @lucianopaz your thoughts on this?

@twiecki
Member

twiecki commented Oct 23, 2020

I agree, but let's leave that for a different PR.

@lucianopaz
Contributor

My personal opinion is that the current Mixture class is a maintenance nightmare. There are multiple branches in the initialization, logp and (worst of all) random methods. I think that it should be heavily refactored. The alternative, which I'm in favor of, is to add two new distributions that are simpler and more concise. We can eventually deprecate our Mixture class (or convert it to a factory function) in favor of the two separate classes. This PR just adds the same-family version. Another PR could deal with the other case (where the components are specified as a list of distributions).

@junpenglao
Member

junpenglao commented Oct 23, 2020

@lucianopaz I like your approach. Could you open an issue with the steps, so we are clear on the roadmap?

@lucianopaz
Contributor

Sure, @junpenglao. I'll write something over the weekend.

@twiecki
Member

twiecki commented Oct 23, 2020

Focusing on this PR, is there anything missing before we can merge?

@Sayam753
Member Author

Yes. I am writing tests with the MvNormal distribution. Will push here soon.

@lucianopaz
Contributor

@twiecki, nothing really big is missing. We should address Alex's comments on the docstrings before merging, and then also add a test with MvNormal.

@Sayam753
Member Author

Hi @lucianopaz
I have been going through the Mixture distribution's random method implementation, and these lines caught my attention. Do we need to handle similar broadcasting in the MixtureSameFamily distribution as well, in case the observed data has more dimensions in batch_shape?

At present, this is the behaviour of the MixtureSameFamily distribution:

with pm.Model() as model:
    mu = pm.Gamma("mu", 1.0, 1.0, shape=2)
    comp_dists = pm.Poisson.dist(mu)
    mix = pm.MixtureSameFamily("mix", w=np.ones(2)/2, comp_dists=comp_dists, shape=(1000, ))
    prior = pm.sample_prior_predictive(samples=10)

prior['mix'].shape  
# outputs (10, ) 

@junpenglao
Member

Do we need to handle similar broadcasting in MixtureSameFamily distribution as well in case observed data has more dimensions in batch_shape?

What is the logp behavior in this case?

@Sayam753
Member Author

Considering how logp is calculated for the Mixture distribution here in the case of observed data, it will certainly be different from the logp calculations for MixtureSameFamily.

@junpenglao
Member

I would imagine a similar handling here in MixtureSameFamily for the same reason.

@lucianopaz
Contributor

Good catch @Sayam753! From what I can tell, the logp seems to be fine in that it always adds the mixture axis to the value tensor. However, we do have to check that the value broadcasts with self.shape.
The hard part is that we need to add some similar checks to random.
Regarding prior predictive sampling: it will always ignore the observed values and the shape they carry. You have to check what happens with the posterior predictive samples.
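
A sketch of the kind of broadcast check mentioned above, in numpy terms as it would apply inside random where everything is a concrete array (this is an assumed approach, not the PR's implementation):

import numpy as np

def _assert_broadcastable(value_shape, dist_shape):
    # np.broadcast raises a ValueError if the two shapes are incompatible
    np.broadcast(np.empty(value_shape), np.empty(dist_shape))

_assert_broadcastable((20, 1000), (1000,))    # fine: an extra batch dimension on the left
# _assert_broadcastable((20, 999), (1000,))   # would raise ValueError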

@Sayam753
Member Author

It's time to watch @lucianopaz's PyMCon talk to learn more about posterior predictive sampling. 🥳

@Sayam753
Member Author

MvNormal is very tricky to understand.

>>> # 10 batch, 3 variate Gaussian
>>> mu = np.random.randn(10, 3)
>>> mat = np.random.randn(3, 3)
>>> cov = mat @ mat.T
>>> chol = np.linalg.cholesky(cov)
>>> dist = tfp.distributions.MultivariateNormalTriL(loc=mu, scale_tril=chol)
>>> dist.batch_shape, dist.event_shape
(TensorShape([10]), TensorShape([3]))
>>> dist.sample(10).shape
TensorShape([10, 10, 3])
>>>
>>> comp_dists = pm.MvNormal.dist(mu=mu, chol=chol, shape=(10, 3))
>>> comp_dists.random(size=(10)).shape
(10, 3)
>>> comp_dists.random(size=(20)).shape
(20, 10, 3)

Is this a broadcasting issue with MvNormal (when sample_shape equals batch_shape) or is this expected behaviour?

@junpenglao
Member

There is a bug there I believe...

@Sayam753
Member Author

So, should there be a new issue opened regarding this?

@twiecki
Member

twiecki commented Oct 25, 2020

@Sayam753 Yes.

Handled broadcasting in case observed data has more batch dimensions

Written tests for MvNormal
@Sayam753
Member Author

I have addressed all the suggestions, and broadcasting with respect to observed data is also handled in the latest commit. So, how about another round of review?

@lucianopaz
Contributor

Thanks @Sayam753. I'll review later today

Added a mention in RELEASE-NOTES.md
Contributor

@AlexAndorra left a comment

This looks really good to me now, thanks a lot @Sayam753!
I'll wait for Luciano's approval before merging.

@lucianopaz
Contributor

Thanks @Sayam753! I'll approve and merge because it looks like everything is in order now. I'm not sure if every condition is completely tested, but we'll find out as people start using this distribution for themselves.

@lucianopaz lucianopaz merged commit 9373d5a into pymc-devs:master Oct 26, 2020
@Sayam753 Sayam753 deleted the mix_same branch October 27, 2020 07:44
ccaprani pushed a commit to ccaprani/pymc that referenced this pull request Oct 31, 2020
* Added mixture same distribution and its tests

Co-authored-by: lucianopaz <luciano.paz.neuro@gmail.com>

* Fixed pyupgrade error

* Fixed suggestions

* Written tests for broadcasting

Handled broadcasting in case observed data has more batch dimensions

Written tests for MvNormal

* Added MixtureSameFamily name in rst files

Added a mention in RELEASE-NOTES.md

Co-authored-by: lucianopaz <luciano.paz.neuro@gmail.com>