Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed up posterior predictive sampling #6208

Merged
merged 8 commits into from
Oct 27, 2022

Conversation

OriolAbril
Copy link
Member

@OriolAbril OriolAbril commented Oct 11, 2022

The goal of this PR is to accelerate the dataset_to_point_list function which right now is
often the bottleneck of posterior predictive sampling. Moreover, I would also like to add some
extra flexibility on the dimensions that are considered sample dimensions.

related to #5160

Checklist

Bugfixes / New features

  • Added sample_dims argument to sample_posterior_predictive.

Docs / Maintenance

  • Improved the performance of sample_posterior_predictive when using InferenceData or Dataset as input.

@codecov
Copy link

codecov bot commented Oct 12, 2022

Codecov Report

Merging #6208 (75a032e) into main (d47dac0) will increase coverage by 0.19%.
The diff coverage is 96.77%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6208      +/-   ##
==========================================
+ Coverage   93.58%   93.77%   +0.19%     
==========================================
  Files         101      101              
  Lines       22136    22232      +96     
==========================================
+ Hits        20716    20849     +133     
+ Misses       1420     1383      -37     
Impacted Files Coverage Δ
pymc/backends/arviz.py 90.08% <87.50%> (-0.53%) ⬇️
pymc/sampling.py 82.53% <100.00%> (+0.05%) ⬆️
pymc/parallel_sampling.py 85.52% <0.00%> (-0.24%) ⬇️
pymc/backends/ndarray.py 79.27% <0.00%> (-0.19%) ⬇️
pymc/tests/distributions/test_truncated.py 99.48% <0.00%> (ø)
pymc/tests/backends/test_arviz.py 99.02% <0.00%> (+0.01%) ⬆️
pymc/distributions/discrete.py 99.25% <0.00%> (+0.03%) ⬆️
pymc/distributions/continuous.py 97.56% <0.00%> (+0.06%) ⬆️
pymc/distributions/multivariate.py 92.32% <0.00%> (+0.06%) ⬆️
... and 7 more

@OriolAbril
Copy link
Member Author

Got the proof of concept working with xarray-einstats and einops. Will write a simple xarray reshaper function to avoid the extra dependency. The reshape we need here is the simplest case supported by those.

@OriolAbril
Copy link
Member Author

Needs arviz-devs/arviz#2138 to get all tests to pass.

Copy link
Contributor

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @OriolAbril . Could you tell me where the speed up is coming from?

pymc/backends/arviz.py Outdated Show resolved Hide resolved
pymc/backends/arviz.py Outdated Show resolved Hide resolved
pymc/sampling.py Show resolved Hide resolved
pymc/util.py Show resolved Hide resolved
stacked_dict = {
vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items()
}
points = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we could yield instead of returning the whole list at once?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with using a lazy generator approach unless the whole list is needed at once for some reason

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that works later in the code then yes! I only kept the list because the function is called _to_list. You should assume I have no idea about the format we need to interface with the aesara random drawing function.

Here would that be using a () comprehension or an explicit loop with a yield? Or either?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the () comprehensión more

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me too

Copy link
Member

@ricardoV94 ricardoV94 Oct 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed the code downstream of this may be incompatible with generators (it asks for len and sometimes to check the first point...)

@ricardoV94
Copy link
Member

@OriolAbril can we split the speedup fix of this PR from the flexible sample_dims? IIRC we don't depend on Arviz for the first and could get it merged sooner?

@OriolAbril
Copy link
Member Author

I plan to release ArviZ on Saturday. In my case the limiting factor is not the ArviZ release but my own time availability. If I have to split the PR it will take longer. If you need that before feel free to split the PR. It doesn't matter if the functionality gets merged in 1 or 2 PRs

@ricardoV94
Copy link
Member

I plan to release ArviZ on Saturday. In my case the limiting factor is not the ArviZ release but my own time availability. If I have to split the PR it will take longer. If you need that before feel free to split the PR. It doesn't matter if the functionality gets merged in 1 or 2 PRs

Saturday should be fine. Let me know if there's anything I can help with otherwise.

@ricardoV94 ricardoV94 changed the title speed up pp sampling speed up posterior predictive sampling Oct 25, 2022
@OriolAbril
Copy link
Member Author

Addressed some of the comments but not all of them.

I switched the list to generator in dataset_to_point_list and got a few errors plus a mypy hell. I can probably update the code to use the generator but I can't do it today, it will need to be toward the end of the week or beginning of the next. I think it can also be a follow-up PR or if someone else has time it can be pushed here directly too.

I also used a tuple for the sample dims, but arviz expects a list so it is either using a list from the start or using a tuple and converting it to a list later on. I don't really care either way but for now left the list from the start.

I ran tests and mypy locally so I expect all tests to pass and the PR to be ready to merge after that.

@OriolAbril OriolAbril marked this pull request as ready for review October 25, 2022 14:21
@ricardoV94
Copy link
Member

@OriolAbril Can you add a bullet point in the top post under the Docs / Maintenance section?

@ricardoV94 ricardoV94 merged commit 570e6e8 into pymc-devs:main Oct 27, 2022
@OriolAbril OriolAbril deleted the speed_up_pp_sampling branch October 27, 2022 09:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants