Add a flag to LKJCorr to return the unpacked correlation matrix #7100
…the upper triangular part as vector
sd = pm.Exponential("sd", 1.0, shape=3)

corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True)
# pylint: enable=unpacking-non-sequence
# pylint: enable=unpacking-non-sequence
mv = pm.MvNormal("mv", mu, cov=sd * (sd * corr).T, size=4)
prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)

assert prior["mv"].shape == (10, 4, 3)
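The covariance expression in this test, sd * (sd * corr).T, relies on broadcasting. A minimal NumPy sketch, using made-up values for sd and corr that are not taken from the PR, suggests it matches the more familiar diag(sd) @ corr @ diag(sd) form when corr is symmetric:

```python
import numpy as np

# Illustrative values only; they are not taken from the PR.
sd = np.array([1.0, 2.0, 0.5])
corr = np.array(
    [
        [1.0, 0.3, -0.2],
        [0.3, 1.0, 0.1],
        [-0.2, 0.1, 1.0],
    ]
)

# Broadcasting form, as written in the test.
cov_broadcast = sd * (sd * corr).T

# Textbook form: D @ R @ D with D = diag(sd).
cov_textbook = np.diag(sd) @ corr @ np.diag(sd)

# Both forms agree, and the diagonal holds the variances sd**2.
assert np.allclose(cov_broadcast, cov_textbook)
assert np.allclose(np.diag(cov_broadcast), sd**2)
```

The broadcasting form avoids building two diagonal matrices, which is presumably why the test uses it.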
I would like to directly assert that the values of corr are correct.
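One possible way to check the values directly, sketched here with plain NumPy. The helper name check_corr_matrix is hypothetical and is not part of PyMC or of this PR; it tests the defining properties of a correlation matrix rather than specific numbers:

```python
import numpy as np


def check_corr_matrix(corr, atol=1e-8):
    """Return True if `corr` looks like a valid correlation matrix.

    Hypothetical helper for illustration; not part of PyMC.
    """
    corr = np.asarray(corr, dtype=float)
    if corr.ndim != 2 or corr.shape[0] != corr.shape[1]:
        return False
    symmetric = np.allclose(corr, corr.T, atol=atol)
    unit_diag = np.allclose(np.diag(corr), 1.0, atol=atol)
    bounded = bool(np.all(np.abs(corr) <= 1.0 + atol))
    # Positive semidefinite: all eigenvalues non-negative up to tolerance.
    psd = bool(np.all(np.linalg.eigvalsh(corr) >= -atol))
    return symmetric and unit_diag and bounded and psd
```

In the test above this could be applied to each prior draw, for example all(check_corr_matrix(c) for c in prior["corr"]), assuming the draws are stored with shape (10, 3, 3).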
pymc/distributions/multivariate.py
Outdated
    return _LKJCorr(name, eta=eta, n=n, **kwargs)
else:
    c_vec = _LKJCorr(name + "_raw", eta=eta, n=n, **kwargs)
    return pm.Deterministic(name, cls.vec_to_corr_mat(c_vec, n))
Not sure about the name change and wrapping in a Deterministic by default, but not completely against it either
The argument for the packed form is that it saves memory. This change actually more than doubles memory use, because we store both the packed form and the dense one. We could leave it to users to wrap the result in a Deterministic if they need it after sampling.
OTOH, this is similar to what LKJ does by default now.
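To put rough numbers on the memory argument, here is a small sketch that counts stored values per draw. It ignores dtype size and any sampler overhead, and the function names are illustrative only:

```python
def packed_size(n):
    """Number of strictly upper-triangular entries of an n x n matrix."""
    return n * (n - 1) // 2


def dense_size(n):
    """Number of entries in the full n x n matrix."""
    return n * n


def stored_when_keeping_both(n):
    """Values kept per draw if both the packed vector and the matrix are saved."""
    return packed_size(n) + dense_size(n)


for n in (3, 10, 50):
    ratio = stored_when_keeping_both(n) / packed_size(n)
    print(f"n={n}: packed={packed_size(n)}, dense={dense_size(n)}, ratio={ratio:.2f}")
```

Since the dense matrix alone holds 2n/(n-1) times as many values as the packed vector, keeping both costs more than three times the packed storage for any n > 1.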
Maybe check the discussion #6828 about setting internal deterministics/coords?
Reminds me I need to finish that PR
pymc/distributions/multivariate.py
Outdated
def lkjcorr_default_transform(op, rv):
    return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))


# Thin wrapper around _LKJCorr
Comments like this have a tendency to get left behind. I would remove it or move it inside the class definition.
# Thin wrapper around _LKJCorr
Thanks for picking this up @velochy!
Codecov Report
Additional details and impacted files

@@           Coverage Diff           @@
##             main    #7100   +/-   ##
=======================================
  Coverage   92.21%   92.21%
=======================================
  Files         101      101
  Lines       16912    16901    -11
=======================================
- Hits        15595    15586     -9
+ Misses       1317     1315     -2
pymc/distributions/multivariate.py
Outdated
implies a uniform distribution of the correlation matrices;
larger values put more weight on matrices with few correlations.
return_matrix : bool, default=False
    If True, returns the full correllation matrix.
Suggested change:
- If True, returns the full correllation matrix.
+ If True, returns the full correlation matrix.
9a64569 to 98a7f48
Made changes corresponding to all the comments so far.
Looks pretty good! Left some minor suggestions
pymc/distributions/multivariate.py
Outdated
@classmethod
def dist(cls, n, eta, *, return_matrix=True, **kwargs):
    # compute Cholesky decomposition
# compute Cholesky decomposition
pymc/distributions/multivariate.py
Outdated
def __new__(cls, name, n, eta, *, return_matrix=False, **kwargs):
    if not return_matrix:
        return _LKJCorr(name, eta=eta, n=n, **kwargs)
Can just call once and then branch in the return, now that both calls are the same.
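A rough sketch of the shape this suggestion points at, written with pure-Python stand-ins so it runs on its own. The names _make_packed, _unpack, and make_corr are hypothetical placeholders for the _LKJCorr call and the vector-to-matrix conversion; this is not the code that was merged:

```python
def _make_packed(name, n, eta):
    """Hypothetical stand-in for the single _LKJCorr(...) call."""
    return {"name": name, "n": n, "eta": eta, "kind": "packed"}


def _unpack(packed):
    """Hypothetical stand-in for converting the packed vector to a matrix."""
    return {**packed, "kind": "matrix"}


def make_corr(name, n, eta, *, return_matrix=False):
    # Build the packed variable exactly once ...
    packed = _make_packed(name, n, eta)
    # ... and branch only on what is returned.
    return _unpack(packed) if return_matrix else packed
```

Keeping a single construction call means the two branches cannot drift apart in the arguments they pass.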
pymc/distributions/multivariate.py
Outdated
return_matrix : bool, default=False
    If True, returns the full correlation matrix.
    False only returns the values of the upper triangular matrix excluding
    diagonal in a single vector of length n(n-1)/2 for backwards compatibility
Suggested change:
- diagonal in a single vector of length n(n-1)/2 for backwards compatibility
+ diagonal in a single vector of length n(n-1)/2 for memory efficiency
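For readers unfamiliar with the packed layout, here is a minimal NumPy sketch of turning a vector of length n(n-1)/2 into a full correlation matrix. The function name unpack_upper_triangle is hypothetical, and the row-major ordering of the upper triangle is an assumption made for illustration; it is not a copy of the PR's vec_to_corr_mat:

```python
import numpy as np


def unpack_upper_triangle(vec, n):
    """Build an n x n correlation matrix from its packed upper triangle.

    Assumes `vec` lists the strictly upper-triangular entries in row-major
    order. Hypothetical helper for illustration only.
    """
    vec = np.asarray(vec, dtype=float)
    expected = n * (n - 1) // 2
    if vec.shape != (expected,):
        raise ValueError(f"expected a vector of length {expected}, got shape {vec.shape}")
    corr = np.eye(n)  # unit diagonal
    rows, cols = np.triu_indices(n, k=1)  # indices above the diagonal
    corr[rows, cols] = vec
    corr[cols, rows] = vec  # mirror to keep the matrix symmetric
    return corr
```

For n=3 the packed vector has three entries, which fill positions (0, 1), (0, 2), and (1, 2) and their mirrored counterparts.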
98a7f48 to ab63678
Thanks, this looks like a great improvement in UX
Description
Provide a wrapper around the previous LKJCorr that reshapes the output into a full correlation matrix instead of the somewhat awkward vector of upper-triangular values.
This is based on the discussion with @ricardoV94 and @jessegrabowski in https://discourse.pymc.io/t/using-lkjcorr-together-with-mvnormal/13606/27.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7100.org.readthedocs.build/en/7100/