Allow for batched alpha in StickBreakingWeights
#6042
Codecov Report
@@            Coverage Diff             @@
##             main    #6042      +/-   ##
==========================================
- Coverage   89.27%   87.44%   -1.83%
==========================================
  Files          72       72
  Lines       12890    12946      +56
==========================================
- Hits        11507    11321     -186
- Misses       1383     1625     +242
Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com>
ab3d6e2 to 7fde39a
Good start. I found some issues with how size is interpreted in conjunction with batched alphas.
larryshamalama
left a comment
Great work @purna135! :)
pymc/tests/test_distributions.py
Outdated
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
    with Model() as model:
def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf):
Is it acceptable to combine test_stickbreakingweights_logp and test_stickbreakingweights_vectorized in a single test?
If test_stickbreakingweights_vectorized fails, it points to shapes not being handled properly when forming batches. So let's keep the two tests separate for better isolation.
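To make the shape concern concrete, here is a minimal NumPy sketch of what a vectorized check compares: the log-density evaluated on the whole batch versus the same function applied one batch element at a time. The function `stick_breaking_logp` is a hypothetical helper derived from the change of variables for Beta(1, alpha) stick fractions. It is not the library's implementation, and the `alpha` and `K` values are only borrowed from the parametrization in this PR.

```python
import numpy as np


def stick_breaking_logp(value, alpha):
    """Hypothetical NumPy log-density of truncated stick-breaking weights.

    value: array of shape (..., K + 1) whose last axis sums to 1.
    alpha: array of shape (...) with positive concentrations.
    """
    value = np.asarray(value, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    K = value.shape[-1] - 1
    # Reversed cumulative sums: w_K, w_K + w_{K-1}, ..., 1 (Jacobian terms).
    tails = np.cumsum(value[..., ::-1], axis=-1)
    return (
        K * np.log(alpha)
        + alpha * np.log(value[..., -1])
        - np.log(tails).sum(axis=-1)
    )


rng = np.random.default_rng(0)
alpha = np.arange(1, 7, dtype="float64").reshape(2, 3)
K = 5
# Any points on the simplex will do as test values.
value = rng.dirichlet(np.ones(K + 1), size=alpha.shape)

# Evaluate on the whole batch at once ...
batched = stick_breaking_logp(value, alpha)

# ... and one batch element at a time.
looped = np.empty(alpha.shape)
for idx in np.ndindex(alpha.shape):
    looped[idx] = stick_breaking_logp(value[idx], alpha[idx])
```

If the two results disagree, the problem is in how batch dimensions are broadcast rather than in the density formula, which is the isolation argued for above.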
I have now separated the test for batched alpha.
I need some assistance calculating
Left a comment above about the fixture. Also, don't forget @larryshamalama's remark above that we should test the pymc/pymc/tests/test_distributions_moments.py Lines 1171 to 1174 in 7af102d
It should be enough to test with a vector of two alphas: one that is already tested as a single alpha (reusing the same K), and the other an extreme value such as alpha=1 or alpha=0 (if that's valid), which might have a very simple moment.
Yes, I got the test for pymc/pymc/tests/test_distributions_moments.py Lines 1147 to 1151 in 7af102d
You can check what the moment is for two distinct single alphas; it should be the same for a batched alpha that contains those two values.
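A minimal NumPy sketch of that check, assuming the moment is the analytical mean of the truncated stick-breaking weights (each stick fraction is Beta(1, alpha), so its mean is 1 / (1 + alpha)). The helper `stick_breaking_mean` is hypothetical and only mirrors the idea; whether it matches the library's moment exactly is an assumption.

```python
import numpy as np


def stick_breaking_mean(alpha, K):
    """Hypothetical helper: mean of the K + 1 truncated stick-breaking weights."""
    # Add a trailing axis so alpha broadcasts against the K stick positions.
    alpha = np.asarray(alpha, dtype=float)[..., None]
    p_break = 1.0 / (1.0 + alpha)   # mean fraction broken off each stick
    p_keep = alpha / (1.0 + alpha)  # mean fraction left over
    head = p_break * p_keep ** np.arange(K)  # first K weights
    tail = p_keep ** K                       # remaining mass in the last weight
    return np.concatenate([head, tail], axis=-1)


K = 2
single_a = stick_breaking_mean(1.0, K)
single_b = stick_breaking_mean(3.0, K)
batched = stick_breaking_mean([1.0, 3.0], K)
```

Each row of `batched` should equal the corresponding single-alpha result, and every row should sum to 1.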
OK, got it now. Do I need to create a separate test for batched alpha as we did in
Nope, you can just add it as an extra condition in the existing tests. Moments are less sensitive than logp, so we can keep them bundled together.
Looks complete to me. The failing test is unrelated. I just asked whether @larryshamalama could leave a review as well.
larryshamalama
left a comment
Looks great @purna135! Just one question for myself 😅
        (np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
    ],
)
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
Is stickbreakingweights_logpdf passed as an argument here via the fixture sharing decorator of the function?
Yes, that's how fixtures work
What is this PR about?
Addressing #5383
This enables StickBreakingWeights' alpha to accept batched data (>2D), makes infer_shape work with batched data, and fixes rng_fn by broadcasting alpha to K.
Checklist
Major / Breaking Changes
Bugfixes / New features
StickBreakingWeights now supports batched alpha parameters
Docs / Maintenance
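For readers unfamiliar with the change, the idea of broadcasting alpha to K when drawing samples can be sketched in plain NumPy as follows. This is an illustration under assumptions, not the code in this PR: `sample_stick_breaking` is a hypothetical name, and it ignores the size argument entirely.

```python
import numpy as np


def sample_stick_breaking(alpha, K, rng):
    """Hypothetical sampler: one draw of K + 1 weights per batch element of alpha."""
    alpha = np.asarray(alpha, dtype=float)
    batch_shape = alpha.shape
    # Broadcast alpha along a new trailing axis of length K (one entry per stick).
    alpha_k = np.broadcast_to(alpha[..., None], batch_shape + (K,))
    betas = rng.beta(1.0, alpha_k)  # fraction broken off at each step
    ones = np.ones(batch_shape + (1,))
    # Length of stick remaining before each break: 1, (1 - b0), (1 - b0)(1 - b1), ...
    remaining = np.concatenate([ones, np.cumprod(1.0 - betas, axis=-1)], axis=-1)
    # The last weight takes whatever is left, hence the trailing fraction of 1.
    fractions = np.concatenate([betas, ones], axis=-1)
    return fractions * remaining


rng = np.random.default_rng(0)
alpha = np.arange(1, 7, dtype="float64").reshape(2, 3)
weights = sample_stick_breaking(alpha, K=5, rng=rng)
```

With alpha of shape (2, 3) and K=5, `weights` has shape (2, 3, 6), and each vector along the last axis sums to 1.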