Allow for batched `alpha` in StickBreakingWeights #6042
Codecov Report

```
@@            Coverage Diff             @@
##             main    #6042      +/-   ##
==========================================
- Coverage   89.27%   87.44%    -1.83%
==========================================
  Files          72       72
  Lines       12890    12946      +56
==========================================
- Hits        11507    11321     -186
- Misses       1383     1625     +242
```
Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com>

Force-pushed from ab3d6e2 to 7fde39a
Good start. I found some issues with the interpretation of `size` in conjunction with batched `alpha`s.

changed the title: Allow `alpha` to take batched data for StickBreakingWeights → Allow for batched `alpha` in StickBreakingWeights
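As a reference point for the `size`-vs-batched-`alpha` question, the shape bookkeeping can be sketched in plain NumPy. This is illustrative only (the names are made up, not the PyMC internals): a requested `size` has to broadcast against `alpha`'s batch shape, and every draw appends K + 1 weights.

```python
import numpy as np

# Hypothetical shape bookkeeping for size together with a batched alpha
# (a sketch, not the PyMC implementation).
K = 5
batched_alpha = np.ones((2, 3))   # batched concentration parameter
size = (4, 2, 3)                  # requested draws; must broadcast
                                  # against alpha's batch shape
batch_shape = np.broadcast_shapes(size, batched_alpha.shape)
weights_shape = (*batch_shape, K + 1)  # K breaks produce K + 1 weights
print(weights_shape)  # -> (4, 2, 3, 6)
```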
Great work @purna135! :)
pymc/tests/test_distributions.py (Outdated)

```diff
-def test_stickbreakingweights_logp(self, value, alpha, K, logp):
-    with Model() as model:
+def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf):
```
Is it acceptable to combine `test_stickbreakingweights_logp` and `test_stickbreakingweights_vectorized` into a single test?
If `test_stickbreakingweights_vectorized` fails, it would point to shapes not being handled properly to form batches. So, let's keep them separate for better test isolation.
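To make the vectorized check concrete, here is a stand-alone sketch of the property it verifies: a hand-derived logpdf for Beta(1, alpha) stick-breaking weights (derived independently here; not PyMC's actual logp implementation) should give the same values for a batched `alpha` as the stacked scalar calls.

```python
import numpy as np

def sbw_logpdf(w, alpha, K):
    """Log-density of K+1 stick-breaking weights with Beta(1, alpha)
    sticks; a hand-derived stand-in, not PyMC's implementation."""
    alpha = np.asarray(alpha, dtype=float)
    # Remaining stick lengths r_1 .. r_K after each of the K breaks
    remaining = 1.0 - np.cumsum(w[..., :-1], axis=-1)
    return (
        K * np.log(alpha)
        + (alpha - 1.0) * np.log(w[..., -1])
        - np.sum(np.log(remaining[..., :-1]), axis=-1)  # Jacobian term over r_1..r_{K-1}
    )

K = 3
w = np.full(K + 1, 1.0 / (K + 1))     # a valid point on the simplex
alphas = np.array([1.0, 2.5])
batched = sbw_logpdf(w, alphas, K)
stacked = np.array([sbw_logpdf(w, a, K) for a in alphas])
assert np.allclose(batched, stacked)
```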
Now I have separated the test for batched `alpha`. I need some assistance calculating …
Left a comment above about the fixture. Also, don't forget @larryshamalama's remark above that we should test the moment (pymc/pymc/tests/test_distributions_moments.py, lines 1171 to 1174 in 7af102d).

Should be enough to test with a vector of two alphas: maybe one that is already tested for a single alpha (reusing the same K), and the other an extreme value like alpha=1 or alpha=0 (if that's valid), which might have a very simple moment.
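The batched-vs-scalar moment check suggested above can be sketched with the analytic means for Beta(1, alpha) sticks. This is a hand-derived sketch, not PyMC's moment code: a batched `alpha` should reproduce the rows of the scalar calls.

```python
import numpy as np

def sbw_mean(alpha, K):
    """Mean of stick-breaking weights with Beta(1, alpha) sticks.
    E[v_k] = 1/(1+alpha), so E[w_k] = (alpha/(1+alpha))**k / (1+alpha)
    for k = 0..K-1, and the leftover weight has mean (alpha/(1+alpha))**K.
    A hand-derived sketch, not PyMC's moment implementation."""
    alpha = np.asarray(alpha, dtype=float)[..., None]
    ratio = alpha / (1.0 + alpha)
    body = ratio ** np.arange(K) / (1.0 + alpha)
    tail = ratio**K
    return np.concatenate([body, tail], axis=-1)

alphas = np.array([1.0, 3.0])
batched = sbw_mean(alphas, 5)                          # shape (2, 6)
singles = np.stack([sbw_mean(1.0, 5), sbw_mean(3.0, 5)])
assert np.allclose(batched, singles)
assert np.allclose(batched.sum(axis=-1), 1.0)          # means lie on the simplex
```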
Yes, I got the test for pymc/pymc/tests/test_distributions_moments.py, lines 1147 to 1151 in 7af102d.
You can check what the moment is for two distinct single alphas, and it should be the same for a batched alpha that has those two values.
Ok, got it now. Do I need to create a separate test for batched `alpha` as we did in …?
Nope, you can just add it as an extra condition in the existing tests. Moments are less sensitive than logp, so we can keep them bundled together.
Looks complete to me. The failing test is unrelated. Just asked if @larryshamalama could leave a review as well.
Looks great @purna135! Just one question for myself 😅
```python
        (np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
    ],
)
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
```
Is `stickbreakingweights_logpdf` passed as an argument here via the fixture-sharing decorator of the function?
Yes, that's how fixtures work
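A minimal illustration of that pytest behaviour (toy bodies, not the actual test suite): any test that lists a fixture's name as a parameter receives the fixture's return value.

```python
import pytest

@pytest.fixture
def stickbreakingweights_logpdf():
    # In the real suite this would return the compiled logpdf;
    # here it is just a placeholder callable.
    return lambda value, alpha, k: 0.0

def test_uses_fixture(stickbreakingweights_logpdf):
    # pytest matches the argument name against known fixtures and
    # injects the fixture's return value automatically.
    assert callable(stickbreakingweights_logpdf)
```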
What is this PR about?

Addressing #5383

This enables `StickBreakingWeights`' `alpha` to accept batched data (>2D), makes `infer_shape` work with batched data, and fixes the `rng_fn` by broadcasting `alpha` to `K`.

Checklist
Major / Breaking Changes

Bugfixes / New features

- `StickBreakingWeights` now supports batched `alpha` parameters

Docs / Maintenance
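The `rng_fn` fix described in the PR summary (broadcasting `alpha` before drawing the sticks) can be sketched in plain NumPy. This is an assumption about the mechanics, not the exact PyMC code:

```python
import numpy as np

def sbw_rng_fn(rng, alpha, K, size=None):
    """Draw stick-breaking weights; a NumPy sketch of the idea
    (not the PyMC implementation)."""
    alpha = np.asarray(alpha, dtype=float)
    if size is None:
        size = alpha.shape
    # Broadcast alpha to one value per stick: shape (*size, K).
    betas = rng.beta(1.0, np.broadcast_to(alpha[..., None], (*size, K)))
    sticks = np.cumprod(1.0 - betas, axis=-1)
    ones = np.ones((*size, 1))
    # w_0 = b_0, w_i = b_i * prod_{j<i}(1 - b_j), w_K = prod_j(1 - b_j)
    return np.concatenate([betas, ones], axis=-1) * np.concatenate([ones, sticks], axis=-1)

rng = np.random.default_rng(42)
draws = sbw_rng_fn(rng, np.arange(1, 7, dtype="float64").reshape(2, 3), K=5)
assert draws.shape == (2, 3, 6)                 # batch shape + K + 1 weights
assert np.allclose(draws.sum(axis=-1), 1.0)     # each draw lies on the simplex
```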