-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix error with LKJCorr default transform #7023
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #7023 +/- ##
===========================================
- Coverage 92.19% 56.77% -35.42%
===========================================
Files 101 101
Lines 16893 16892 -1
===========================================
- Hits 15575 9591 -5984
- Misses 1318 7301 +5983
|
@ricardoV94 I want to kick start this PR. In 242be23 I simply removed the condition. I am still unsure where (which line) and how to sum the jabian. Should it be something like the statement above? diff_ndims = log_jac_det.ndim - logp.ndim
log_jac_det = log_jac_detsum(axis=np.arange(-diff_ndims, 0)) ? Or am I totally lost 🙈 ? |
You're correct, and that should be inside the old elif statement, so don't remove it, just put that inside instead of raising. We should add a comment explaining when does that branch get triggered: univariate transform applied to a multivariate RV. Also keep the old link, we should still mention this is not valid for "non full-rank multivariate" distributions (don't know a better name) like Dirichlet or ZeroSumNormal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. We should add a test for the LKJCorr.
We should check that model.compile_logp(sum=False)(model.initial_point())
has the same shape with and without the default transform. Before we introduced this restriction it didn't.
[I will be working on the tests soon ... the end of the year is always a bit hectic 😅 ] |
c49bf6d
to
a732444
Compare
@ricardoV94 all tests passed 🙌 ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should revert the unrelated changes. Your editor or pre-commit is probably more opinionated than the PyMC one
def test_lkjcorr_default_transform(jacobian, n): | ||
with pm.Model() as m: | ||
pm.LKJCorr("Ω_triu", eta=1, n=n, transform=None) | ||
assert m.compile_logp(jacobian=jacobian, sum=False)(m.initial_point())[0] == m.compile_logp( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure about the test.
Perhaps it's enough to test the default transform is Interval and that calling logp doesn't fail and has the right shape. This can be done in tests/distributions/test_multivariate.py
as a specific regression test for the default transform of this Distribution (if this test had existed we would have had to address it when we did the breaking changes before)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey! Thanks! Some questions:
Perhaps it's enough to test the default transform is Interval
Do you mean m.compile_logp(jacobian=jacobian, sum=False)(m.initial_point())
is in certain interval? which one?
calling logp doesn't fail and has the right shape.
Don't we have this already?
This can be done in tests/distributions/test_multivariate.py as a specific regression test for the default transform of this Distribution
So you mean we move this test to a new one in tests/distributions/test_multivariate.py
with pm.Model() as m: | ||
pm.Dirichlet("x", [1, 1, 1], transform=tr.log) | ||
assert m.logp(jacobian=jacobian).type.shape == () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add batch dims and test the logp is the expected one.
Something like logp(dirichlet_dist, pt.log(test_value)) + LogTransform.log_jac_det(test_value).sum(-1)
Also make sure it's not just -inf.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit lost. What would be the assert statement ?
I know assert np.all(pm.logp(x, pt.log(test_value)).eval() < np.inf)
✅ , but what is the expected relationship with tr.LogTransform().log_jac_det(test_value).sum(-1)
?
0edd7b4
to
4f07ded
Compare
Dam! I erased the changes trying to revert 🤦 . I will try to bring them back |
ok! seems we are back 😅 |
ec36918
to
0fd7b9e
Compare
Closes #7002
📚 Documentation preview 📚: https://pymc--7023.org.readthedocs.build/en/7023/