Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring back SMC and allow prior_predictive_sampling to return transformed values #4769

Merged
merged 5 commits into from
Jun 16, 2021

Conversation

ricardoV94
Copy link
Member

SMC was broken after the refactoring because it starts with a prior_predictive_sampling call to set up the particles positions, expecting it to return also transformed values. I extended prior_predictive to return transformed values if (and only if) transformed variables are explicitly passed in the optional var_names argument. For reference, in v3 transformed variables were returned by default. If anyone has a strong opinion about the old default let me know!

Tests were added for this as well as for the now stale issue #4490

Depending on what your PR does, here are a few things you might want to address in the description:

@ricardoV94 ricardoV94 force-pushed the smc_v4_compat branch 2 times, most recently from ee122d0 to 57e4f5a Compare June 14, 2021 16:15
Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These failing tests look systematic. Some kind of dytpe problem..

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jun 15, 2021

These failing tests look systematic. Some kind of dytpe problem..

Definitely. One of the SMC tests is failing in float32 when there is a discrete variable:

https://github.com/pymc-devs/pymc3/blob/57e4f5a177c98a928c3e590800d1a17af4237b50/pymc3/tests/test_smc.py#L67-L72

It happens in the join_nonshared_inputs which concatenates all the raveled variables, leading to an upcast to float64 when there are both discrete and continuous variables (since discretes are int64).

https://github.com/pymc-devs/pymc3/blob/f7d460212c0539e6c7a7ae394a3f2de6416068c7/pymc3/aesaraf.py#L604-L608

The problem then is that the input can no longer be of type float32 / floatX. I can make the test pass by just wrapping joined in a pm.aesaraf.floatX, confirming the problem lies there. I don't think this is a great fix though...

This could be a problem in other areas of the codebase that make use of this function. I've seen it in metropolis.py, mlda.py and pgbart.py.

Edit: Possibly related to #4553

Edit2: Link to the failing test: https://github.com/pymc-devs/pymc3/runs/2821726372?check_suite_focus=true#step:7:599
(I canceled the last workflow, I had just pushed a rebase and some minor comments, so the tests would still fail)

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jun 15, 2021

Also, more in general, within SMC we are treating discrete variables as continuous (e.g., in the proposal). Are we comfortable with this @aloctavodia? I know that the logp methods deal fine with float inputs, but it still feels strange to feed them non-rounded values.

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jun 16, 2021

I temporarily disabled trust_input to pass the tests. I tried a bunch of things (mostly with clone_replace + casting) to allow the prior_logp_func function to accept float32 inputs in models with discrete variables, without success.

Edit: I restricted this change to when absolutely needed, and issued an informative UserWarning in those cases.

@ricardoV94 ricardoV94 merged commit a90457a into pymc-devs:main Jun 16, 2021
@ricardoV94 ricardoV94 deleted the smc_v4_compat branch June 17, 2021 06:29
@aloctavodia
Copy link
Member

Sorry for being late to the party. Yes I am conformable with treating discrete variables as continuous when proposing new values.

@OriolAbril
Copy link
Member

I saw this checking the release notes in:

pm.sample_prior_predictive no longer returns transformed variable values by default. Pass them by name in var_names if you want to obtain these draws (see 4769).

The converter to InferenceData ignores transformed values by default, so I find the phrasing is a bit misleading and potentially troublesome. We should probably add an argument to the converter to include transformed variables into the inferencedata otherwise we'll need to keep the dict return and the capabilities of the function will depend on the output chosen

@ricardoV94
Copy link
Member Author

ricardoV94 commented Nov 8, 2021

The plan is to revert this, see #5076

This is no longer needed, as we can use the model.initial_point to get transformed prior predictive samples when they are needed for our samplers

@OriolAbril
Copy link
Member

OriolAbril commented Nov 8, 2021

I think I should probably go over issues and PRs in both pymc and arviz and make a serious cleanup of integration with InferenceData, but I don't think I'll have time for a while. We still have arguments that were more workarounds than actual arguments/fixes and should be removed as they are generally useless now (i.e. density_dist_obs in to_inferende_data, keep_size in sample_posterior_predictive), the transforms presence is also annoying: arviz-devs/arviz#1509, arviz-devs/arviz#230, and in general we can simplify the converter quite a bit now that it lives in the pymc codebase and should not need complicated logic to work with multiple pymc versions. I think we could also make pointwise log likeihood storage and posterior predictive sampling work with dask (as in my experience it is common that the model/posterior fits in memory but there are many observations and ll and pp do not, and arviz does support working with dask backed arrays, the main limitation right now is creating those dask backed arrays).

Maybe other improvements are also relaatively low hanging fruit?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
SMC Sequential Monte Carlo
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants