Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow customizing truncation max_n_steps in Hurdle Mixtures #7339

Merged
merged 6 commits into from
May 31, 2024

Conversation

tjburch
Copy link
Contributor

@tjburch tjburch commented May 30, 2024

Description

Separates kwargs going to _hurdle_mixture by name. Gathers any which should redirect to Truncated (lower, upper, and max_n_steps), and otherwise assumes they should go to Mixture.

Related Issue

Reproducing their MWE after changes:

import pymc as pm
import pytensor as pt
with pm.Model():
     ad_nb = pm.HurdleNegativeBinomial('ad_nb', psi=.1, n=4000, p=1 - 5.8 * 1e-5, max_n_steps=10000)
     prior = pm.sample_prior_predictive(samples=100)
print("Done!")

Gives:

WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Sampling: [ad_nb]
Done!

Checklist

  • Checked that the pre-commit linting/style checks pass

  • Included tests that prove the fix is effective or that the new feature works
    No new tests, confirmed existing tests/distributions/test_mixture.py runs without errors (python 3.12)

  • Added necessary documentation (docstrings and/or example notebooks)
    Probably not warranted

  • If you are a pro: each commit corresponds to a relevant logical change

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7339.org.readthedocs.build/en/7339/

@tjburch tjburch changed the title Redirect kwargs to appropriate distribution Redirect Mixture kwargs to appropriate distribution May 30, 2024
@@ -817,15 +817,19 @@ def _hurdle_mixture(*, name, nonzero_p, nonzero_dist, dtype, **kwargs):

nonzero_p = pt.as_tensor_variable(nonzero_p)
weights = pt.stack([1 - nonzero_p, nonzero_p], axis=-1)

truncated_kwargs = {k: v for k, v in kwargs.items() if k in {"lower", "upper", "max_n_steps"}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

upper is not supposed to be part of hurdle, so I would just do max_n_steps, which can be an explicit kwarg in the signature?

Copy link
Contributor Author

@tjburch tjburch May 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Resolved in latest commit.

@tjburch
Copy link
Contributor Author

tjburch commented May 30, 2024

Changing the signature led to some test fails. Looking into it...

]

if max_n_steps:
truncated_dist = Truncated.dist(nonzero_dist, lower=lower, max_n_steps=max_n_steps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a bit cleaner if you just copy the default value to the signature of the hurdle function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine. Wasn't sure if we wanted to avoid doubling up on magic numbers, but if you're good with it, I am. Just made the change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I also don't love it but seems more clean in this case where optional parameters are directed to distinct places. Otherwise we could have explicit truncated_kwargs that the user must pass as a dict.

Anyway the trick of using a Truncated should be a temporary thing, in the long term I would like to make Mixture aware of mixed types and handle it natively

Copy link

codecov bot commented May 31, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.37%. Comparing base (7d15175) to head (1087728).
Report is 114 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7339      +/-   ##
==========================================
+ Coverage   91.25%   92.37%   +1.12%     
==========================================
  Files         102      102              
  Lines       17208    17208              
==========================================
+ Hits        15703    15896     +193     
+ Misses       1505     1312     -193     
Files with missing lines Coverage Δ
pymc/distributions/mixture.py 95.02% <100.00%> (ø)

... and 3 files with indirect coverage changes

@ricardoV94 ricardoV94 merged commit 508a134 into pymc-devs:main May 31, 2024
22 checks passed
@ricardoV94 ricardoV94 changed the title Redirect Mixture kwargs to appropriate distribution Allow customizing truncation max_n_steps in Hurdle Mixtures May 31, 2024
@ricardoV94
Copy link
Member

Thanks @tjburch !

@tjburch tjburch deleted the truncated-kwargs-fix branch May 31, 2024 11:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: passing the 'max_n_steps' parameter as kwarg to HurdleNegativeBinomial distribution does not work
2 participants