Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error with LKJCorr default transform #7023

Closed
wants to merge 0 commits into from

Conversation

juanitorduz
Copy link
Contributor

@juanitorduz juanitorduz commented Nov 21, 2023

Closes #7002


📚 Documentation preview 📚: https://pymc--7023.org.readthedocs.build/en/7023/

@juanitorduz juanitorduz marked this pull request as draft November 21, 2023 19:34
Copy link

codecov bot commented Nov 21, 2023

Codecov Report

Merging #7023 (ec36918) into main (2e05854) will decrease coverage by 35.42%.
Report is 1 commits behind head on main.
The diff coverage is 25.00%.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #7023       +/-   ##
===========================================
- Coverage   92.19%   56.77%   -35.42%     
===========================================
  Files         101      101               
  Lines       16893    16892        -1     
===========================================
- Hits        15575     9591     -5984     
- Misses       1318     7301     +5983     
Files Coverage Δ
pymc/logprob/transform_value.py 70.80% <25.00%> (-22.95%) ⬇️

... and 75 files with indirect coverage changes

@juanitorduz
Copy link
Contributor Author

@ricardoV94 I want to kick start this PR. In 242be23 I simply removed the condition. I am still unsure where (which line) and how to sum the jabian. Should it be something like the statement above?

diff_ndims = log_jac_det.ndim - logp.ndim
log_jac_det = log_jac_detsum(axis=np.arange(-diff_ndims, 0))

?

Or am I totally lost 🙈 ?

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 22, 2023

You're correct, and that should be inside the old elif statement, so don't remove it, just put that inside instead of raising.

We should add a comment explaining when does that branch get triggered: univariate transform applied to a multivariate RV.

Also keep the old link, we should still mention this is not valid for "non full-rank multivariate" distributions (don't know a better name) like Dirichlet or ZeroSumNormal

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. We should add a test for the LKJCorr.

We should check that model.compile_logp(sum=False)(model.initial_point()) has the same shape with and without the default transform. Before we introduced this restriction it didn't.

pymc/logprob/transform_value.py Outdated Show resolved Hide resolved
@juanitorduz
Copy link
Contributor Author

[I will be working on the tests soon ... the end of the year is always a bit hectic 😅 ]

@ricardoV94 ricardoV94 changed the title Fix: LKJCorr default transform Fix error with LKJCorr default transform Dec 9, 2023
@juanitorduz
Copy link
Contributor Author

@ricardoV94 all tests passed 🙌 !

@juanitorduz juanitorduz marked this pull request as ready for review December 12, 2023 13:22
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should revert the unrelated changes. Your editor or pre-commit is probably more opinionated than the PyMC one

tests/distributions/test_transform.py Outdated Show resolved Hide resolved
def test_lkjcorr_default_transform(jacobian, n):
with pm.Model() as m:
pm.LKJCorr("Ω_triu", eta=1, n=n, transform=None)
assert m.compile_logp(jacobian=jacobian, sum=False)(m.initial_point())[0] == m.compile_logp(
Copy link
Member

@ricardoV94 ricardoV94 Dec 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the test.

Perhaps it's enough to test the default transform is Interval and that calling logp doesn't fail and has the right shape. This can be done in tests/distributions/test_multivariate.py as a specific regression test for the default transform of this Distribution (if this test had existed we would have had to address it when we did the breaking changes before)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Thanks! Some questions:

Perhaps it's enough to test the default transform is Interval

Do you mean m.compile_logp(jacobian=jacobian, sum=False)(m.initial_point()) is in certain interval? which one?

calling logp doesn't fail and has the right shape.

Don't we have this already?

This can be done in tests/distributions/test_multivariate.py as a specific regression test for the default transform of this Distribution

So you mean we move this test to a new one in tests/distributions/test_multivariate.py

with pm.Model() as m:
pm.Dirichlet("x", [1, 1, 1], transform=tr.log)
assert m.logp(jacobian=jacobian).type.shape == ()
Copy link
Member

@ricardoV94 ricardoV94 Dec 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add batch dims and test the logp is the expected one.

Something like logp(dirichlet_dist, pt.log(test_value)) + LogTransform.log_jac_det(test_value).sum(-1)

Also make sure it's not just -inf.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit lost. What would be the assert statement ?
I know assert np.all(pm.logp(x, pt.log(test_value)).eval() < np.inf) ✅ , but what is the expected relationship with tr.LogTransform().log_jac_det(test_value).sum(-1) ?

@juanitorduz juanitorduz force-pushed the issue_7002 branch 2 times, most recently from 0edd7b4 to 4f07ded Compare December 12, 2023 15:52
@juanitorduz
Copy link
Contributor Author

Dam! I erased the changes trying to revert 🤦 . I will try to bring them back

@juanitorduz
Copy link
Contributor Author

ok! seems we are back 😅

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: LKJCorr default transform raises error
2 participants