Expand models supported by automatic marginalization #300

Merged

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Jan 28, 2024

This PR allows more kinds of graphs to be marginalized. Previously, we limited marginalization to Elemwise operations, to ensure that information was not mixed across batch dimensions between the marginalized and dependent RVs. That guarantee lets us generate an efficient logp expression that does not grow with the number of batch dimensions, but stays constant on the domain of the marginalized variables.

We still have the same constraint, but can now analyze it more carefully by propagating information about the batch dimensions of the marginalized RV through the intermediate operations. This allows operations like DimShuffle (transposition, expand_dims, squeeze), Blockwise, reductions and, importantly, flavors of basic and advanced indexing like the following:

import pymc as pm
import pytensor.tensor as pt

from pymc_experimental import MarginalModel

with MarginalModel() as m:
    state = pm.Categorical("state", p=[0.1, 0.3, 0.6], shape=(4,))
    # Advanced indexing was not supported before
    # The indexed variable could be an RV as well!
    mu = pt.as_tensor([-10, 0, 10])[state]
    sigma = pm.HalfNormal("sigma")
    emission = pm.Normal("emission", mu, sigma, observed=[-9.0, -0.5, 1.0, 11.0])
    
m.marginalize(state)
m.point_logps()
# {'sigma': -0.73, 'emission': -10.52}

This should expand the range of models supported and open room for further expansions.
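
The reported `emission` logp can be checked by hand, which also illustrates why marginalization stays cheap when dimensions are not mixed. The following is a minimal standard-library sketch, not the code path used by `MarginalModel`. It assumes `sigma = 1.0` (an assumption that appears consistent with the reported `sigma` logp of -0.73), and the helper names `factorized_logp` and `brute_force_logp` are made up for illustration.

```python
import math
from itertools import product

p = [0.1, 0.3, 0.6]                  # Categorical probabilities from the example
mu = [-10.0, 0.0, 10.0]              # one mean per state
observed = [-9.0, -0.5, 1.0, 11.0]   # observed emissions
sigma = 1.0                          # ASSUMPTION: value of sigma at the evaluated point


def normal_logpdf(y, mean, sd):
    """Log density of a Normal(mean, sd) evaluated at y."""
    return -0.5 * math.log(2 * math.pi) - math.log(sd) - 0.5 * ((y - mean) / sd) ** 2


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def factorized_logp():
    """Marginalize each state independently: cost grows as N * K."""
    total = 0.0
    for y in observed:
        total += logsumexp(
            [math.log(pk) + normal_logpdf(y, mk, sigma) for pk, mk in zip(p, mu)]
        )
    return total


def brute_force_logp():
    """Enumerate every joint assignment of the 4 states: cost grows as K ** N."""
    terms = []
    for states in product(range(len(p)), repeat=len(observed)):
        terms.append(
            sum(
                math.log(p[s]) + normal_logpdf(y, mu[s], sigma)
                for s, y in zip(states, observed)
            )
        )
    return logsumexp(terms)


print(round(factorized_logp(), 2))
print(round(brute_force_logp(), 2))
```

Under the stated assumption both functions give about -10.52. The factorized version is only valid because each `emission[i]` depends on `state[i]` alone; if an operation mixed entries of `state`, only the exponential enumeration would remain correct.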

TODO:

  • Give more informative error for explicit broadcastable dims
  • Document internal logic more carefully
  • Allow dependent multivariate RVs (it should almost work out of the box, just need some tweaks and tests)
  • Test subgraph_dims utility directly
  • Test MvNormal k-clusters model that showed up in Discourse some time ago
  • Test dependent RVs batch to the left restriction
  • Allow a couple more simple Ops

@ricardoV94 ricardoV94 added the enhancements New feature or request label Jan 28, 2024
Contributor

@zaxtax zaxtax left a comment

It would be nice if there was some intuition for why some programs can be marginalized and some cannot. The distinction is a bit and for some cases it feels like we should be able to.

@ricardoV94
Member Author

ricardoV94 commented Feb 1, 2024

> It would be nice if there was some intuition for why some programs can be marginalized and some cannot. The distinction is a bit and for some cases it feels like we should be able to.

You can marginalize models as long as they don't mix dimensions of the marginal RV (and as long as we can be sure of that).

So if idx is the variable you are marginalizing, a direct dependent variable could use as a parameter idx + idx or idx.T + idx.T, but not idx + idx.T because it mixes distinct dimensions. Similarly, something like sum(idx) is not allowed.
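
To make "mixing dimensions" concrete, here is a small brute-force probe in plain Python. It is only an illustration and not how the PR performs the check (the PR analyzes the graph symbolically). All helper names (`dependencies`, `mixes_entries`, `add`, `transpose`, `total`) are hypothetical. The probe bumps one input entry at a time and records which output entries change; an expression mixes entries if some output entry reacts to more than one input entry.

```python
from itertools import product


def transpose(m):
    return [list(row) for row in zip(*m)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def total(m):
    # Stand-in for sum(idx): reduces everything to a single entry.
    return [[sum(sum(row) for row in m)]]


def dependencies(fn, shape):
    """Map each output position to the set of input positions it reacts to."""
    rows, cols = shape
    base = [[0] * cols for _ in range(rows)]
    reference = fn(base)
    deps = {}
    for i, j in product(range(rows), range(cols)):
        bumped = [row[:] for row in base]
        bumped[i][j] = 1  # perturb a single input entry
        out = fn(bumped)
        for oi, out_row in enumerate(out):
            for oj, value in enumerate(out_row):
                if value != reference[oi][oj]:
                    deps.setdefault((oi, oj), set()).add((i, j))
    return deps


def mixes_entries(fn, shape):
    """True if any output entry depends on more than one input entry."""
    return any(len(sources) > 1 for sources in dependencies(fn, shape).values())


print(mixes_entries(lambda idx: add(idx, idx), (2, 2)))                        # idx + idx
print(mixes_entries(lambda idx: add(transpose(idx), transpose(idx)), (2, 2)))  # idx.T + idx.T
print(mixes_entries(lambda idx: add(idx, transpose(idx)), (2, 2)))             # idx + idx.T
print(mixes_entries(total, (2, 2)))                                            # sum(idx)
```

The first two expressions keep a one-to-one link between entries of `idx` and entries of the result, so each entry can be marginalized on its own. The last two tie several entries of `idx` to one output entry, which is the case that is rejected.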

@ricardoV94
Member Author

This PR now depends on pymc-devs/pymc#7159

@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch from b0d0bce to ffb82c3 Compare February 16, 2024 18:05
@zaxtax
Contributor

zaxtax commented Jul 26, 2024

Aside from being re-based, are there any blockers to merging this?

@twiecki
Member

twiecki commented Jul 26, 2024

> Aside from being re-based, are there any blockers to merging this?

@ricardoV94 just had a baby so probably won't respond. Do you want to rebase and we can just merge for now?

@ricardoV94
Member Author

ricardoV94 commented Jul 26, 2024

Nah I can still type now and then. PR was not yet done

@zaxtax
Contributor

zaxtax commented Jul 26, 2024

Congrats @ricardoV94 🥳

@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch 2 times, most recently from 9341906 to 01ed4c0 Compare September 18, 2024 12:47
@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch 5 times, most recently from 10e33f1 to 971ed63 Compare October 2, 2024 08:48
@ricardoV94 ricardoV94 requested a review from zaxtax October 2, 2024 08:49
@ricardoV94 ricardoV94 marked this pull request as ready for review October 2, 2024 08:50
@ricardoV94
Member Author

ricardoV94 commented Oct 2, 2024

Failing CI is due to weird stuff between PyMC and PyTensor: pymc-devs/pytensor#1009

Hopefully it will be sorted out soon. Tests pass locally on my machine.

Open for review!

Contributor

@zaxtax zaxtax left a comment

This looks really good. It's a little annoying that the refactor is mixed in with expansion of support for models, but looks pretty good.

Review threads: pymc_experimental/model/marginal/distributions.py (outdated, resolved); tests/model/test_marginal_model.py (resolved)
@ricardoV94
Member Author

> It's a little annoying that the refactor is mixed in with expansion of support for models, but looks pretty good

It's in separate commits

@zaxtax
Contributor

zaxtax commented Oct 2, 2024

> > It's a little annoying that the refactor is mixed in with expansion of support for models, but looks pretty good
>
> It's in separate commits

Yea, I reviewed it commit by commit. I saw the bulk of the work was in the first one.

@zaxtax
Contributor

zaxtax commented Oct 2, 2024

We should probably bump the minimum version required to 5.17

@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch from 971ed63 to 6099681 Compare October 3, 2024 09:15
@ricardoV94
Member Author

> We should probably bump the minimum version required to 5.17

Done

@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch from 6099681 to 7da7229 Compare October 3, 2024 09:42
pytest is configured with the same behavior globally
This commit lifts the restriction that only Elemwise operations may link marginalized to dependent RVs. We map input dims to output dims, to assess whether an operation mixes information from different dims or not. Graphs where information is not mixed can be efficiently marginalized.
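
The idea of mapping input dims to output dims can be sketched as follows. This is a simplified illustration, not the implementation in this PR: the function names (`elemwise_dims`, `dimshuffle_dims`, `reduce_dims`) and the `MixedDimsError` exception are hypothetical. Each variable carries a tuple with one entry per axis: an integer naming the batch dim of the marginalized RV that the axis tracks, or `None` if the axis carries no information from it. Inputs to `elemwise_dims` are assumed to already have the same number of axes.

```python
class MixedDimsError(ValueError):
    """Raised when an operation would mix distinct dims of the marginalized RV."""


def elemwise_dims(*input_dims):
    """Elemwise ops keep axes aligned: every input must agree on each axis."""
    out = []
    for axis_entries in zip(*input_dims):
        known = {d for d in axis_entries if d is not None}
        if len(known) > 1:
            raise MixedDimsError(f"axis combines marginalized dims {sorted(known)}")
        out.append(known.pop() if known else None)
    return tuple(out)


def dimshuffle_dims(input_dims, new_order):
    """Transpose / expand_dims: entries move with their axes, "x" adds a fresh axis."""
    return tuple(None if o == "x" else input_dims[o] for o in new_order)


def reduce_dims(input_dims, axis):
    """Reductions are fine only over axes that carry no marginalized dim."""
    if input_dims[axis] is not None:
        raise MixedDimsError(f"reduction over marginalized dim {input_dims[axis]}")
    return input_dims[:axis] + input_dims[axis + 1:]


idx = (0, 1)                            # a 2D marginalized RV tracks its own dims
idx_T = dimshuffle_dims(idx, (1, 0))    # idx.T -> (1, 0)

print(elemwise_dims(idx, idx))          # idx + idx
print(elemwise_dims(idx_T, idx_T))      # idx.T + idx.T
try:
    elemwise_dims(idx, idx_T)           # idx + idx.T
except MixedDimsError as err:
    print("rejected:", err)
```

In this sketch a graph is accepted when the rules can be applied from the marginalized RV down to the dependent RVs without raising, which corresponds to the case where information is not mixed.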
@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch from 7da7229 to fec9eeb Compare October 4, 2024 08:25
Contributor

@zaxtax zaxtax left a comment

LGTM!

@ricardoV94 ricardoV94 merged commit e96d07f into pymc-devs:main Oct 4, 2024
7 checks passed
@ricardoV94 ricardoV94 deleted the extend_automatic_marginalization branch October 4, 2024 13:40
Labels: enhancements (New feature or request), marginalization
3 participants