-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow customizing truncation max_n_steps in Hurdle Mixtures #7339
Conversation
pymc/distributions/mixture.py
Outdated
@@ -817,15 +817,19 @@ def _hurdle_mixture(*, name, nonzero_p, nonzero_dist, dtype, **kwargs): | |||
|
|||
nonzero_p = pt.as_tensor_variable(nonzero_p) | |||
weights = pt.stack([1 - nonzero_p, nonzero_p], axis=-1) | |||
|
|||
truncated_kwargs = {k: v for k, v in kwargs.items() if k in {"lower", "upper", "max_n_steps"}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
upper is not supposed to be part of hurdle, so I would just do max_n_steps, which can be an explicit kwarg in the signature?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Resolved in latest commit.
Changing the signature led to some test fails. Looking into it... |
pymc/distributions/mixture.py
Outdated
] | ||
|
||
if max_n_steps: | ||
truncated_dist = Truncated.dist(nonzero_dist, lower=lower, max_n_steps=max_n_steps) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps a bit cleaner if you just copy the default value to the signature of the hurdle function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's fine. Wasn't sure if we wanted to avoid doubling up on magic numbers, but if you're good with it, I am. Just made the change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I also don't love it but seems more clean in this case where optional parameters are directed to distinct places. Otherwise we could have explicit truncated_kwargs that the user must pass as a dict.
Anyway the trick of using a Truncated should be a temporary thing, in the long term I would like to make Mixture aware of mixed types and handle it natively
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7339 +/- ##
==========================================
+ Coverage 91.25% 92.37% +1.12%
==========================================
Files 102 102
Lines 17208 17208
==========================================
+ Hits 15703 15896 +193
+ Misses 1505 1312 -193
|
Thanks @tjburch ! |
Description
Separates kwargs going to
_hurdle_mixture
by name. Gathers any which should redirect toTruncated
(lower
,upper
, andmax_n_steps
), and otherwise assumes they should go toMixture
.Related Issue
Reproducing their MWE after changes:
Gives:
Checklist
Checked that the pre-commit linting/style checks pass
Included tests that prove the fix is effective or that the new feature works
No new tests, confirmed existing
tests/distributions/test_mixture.py
runs without errors (python 3.12)Added necessary documentation (docstrings and/or example notebooks)
Probably not warranted
If you are a pro: each commit corresponds to a relevant logical change
Type of change
📚 Documentation preview 📚: https://pymc--7339.org.readthedocs.build/en/7339/