-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix failing default transform for LKJCorr
#7065
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #7065 +/- ##
=======================================
Coverage 92.19% 92.20%
=======================================
Files 101 101
Lines 16893 16901 +8
=======================================
+ Hits 15575 15584 +9
+ Misses 1318 1317 -1
|
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
LKJCorr
Is it good on your end? Asking because it's marked as a draft |
@ricardoV94 With the new suggestion the test fails locally with the NotImplementedError: Univariate transform MultivariateIntervalTransform cannot be applied to multivariate lkjcorr_rv{1, (0, 0), floatX, False} 🤔 |
Where is that check coming from? We might need to add some meta-info to the Transform |
The problem is the logp of the distribution is incorrectly implemented. It's returning a scalar instead of a vector of |
Yeah! It is failing locally. It's good that you caught up on this with the test! Do you think there is an "easy" fix? |
We should add a You can parametrize the test to have two cases pymc/tests/distributions/test_multivariate.py Lines 1048 to 1055 in 9b4bf2a
Also this test shouldn't be in the |
It's not trivial, it requires thinking carefully about batch dimensions, like we did in this PR: #6897 |
Also could you reintroduce the change from the other PR where we always run this check instead of being in the else branch? pymc/pymc/logprob/transform_value.py Lines 128 to 133 in 04a03b5
|
Added the suggested changes :) |
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
We should add a check in the logp similar to this: pymc/pymc/distributions/multivariate.py Lines 1249 to 1252 in e67a317
Should be a NotImplementedError though |
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
We now have two tests failing because of the new tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[3-1-1-expected2] FAILED [ 87%]
tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[5-1-size3-expected3] FAILED [ 87%] |
Yup |
@ricardoV94 we are back to 🟢 :) |
Thank you for all your help @ricardoV94 ❤️ |
Thanks @juanitorduz |
Closes #7002
Wt take a different direction from #7023
📚 Documentation preview 📚: https://pymc--7065.org.readthedocs.build/en/7065/