Skip to content

Commit

Permalink
Add a flag to LKJCorr to return the unpacked correlation matrix (#7100)
Browse files Browse the repository at this point in the history
* LKJCorr now has an option to return the full n-by-n correlation matrix instead of just the upper-triangular part as a vector
  • Loading branch information
velochy authored Jan 23, 2024
1 parent 08eaeb1 commit ef1f91f
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 48 deletions.
133 changes: 90 additions & 43 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,7 @@ def rng_fn(cls, rng, n, eta, size):
def _random_corr_matrix(cls, rng, n, eta, flat_size):
# original implementation in R see:
# https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r

beta = eta - 1.0 + n / 2.0
r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=flat_size, random_state=rng) - 1.0
P = np.full((flat_size, n, n), np.eye(n))
Expand All @@ -1537,48 +1538,8 @@ def log_jac_det(self, *args):
return super().log_jac_det(*args).sum(-1)


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
The LKJ distribution is a prior distribution for correlation matrices.
If eta = 1 this corresponds to the uniform distribution over correlation
matrices. For eta -> oo the LKJ prior approaches the identity matrix.
======== ==============================================
Support Upper triangular matrix with values in [-1, 1]
======== ==============================================
Parameters
----------
n : tensor_like of int
Dimension of the covariance matrix (n > 1).
eta : tensor_like of float
The shape parameter (eta > 0) of the LKJ distribution. eta = 1
implies a uniform distribution of the correlation matrices;
larger values put more weight on matrices with few correlations.
Notes
-----
This implementation only returns the values of the upper triangular
matrix excluding the diagonal. Here is a schematic for n = 5, showing
the indexes of the elements::
[[- 0 1 2 3]
[- - 4 5 6]
[- - - 7 8]
[- - - - 9]
[- - - - -]]
References
----------
.. [LKJ2009] Lewandowski, D., Kurowicka, D. and Joe, H. (2009).
"Generating random correlation matrices based on vines and
extended onion method." Journal of multivariate analysis,
100(9), pp.1989-2001.
"""

# Returns list of upper triangular values
class _LKJCorr(BoundedContinuous):
rv_op = lkjcorr

@classmethod
Expand Down Expand Up @@ -1637,11 +1598,97 @@ def logp(value, n, eta):
)


# NOTE(review): two registrations appear here, apparently because this view
# interleaves the old (LKJCorr) and new (_LKJCorr) sides of a diff — confirm
# which one applies in the actual file.
@_default_transform.register(LKJCorr)
@_default_transform.register(_LKJCorr)
def lkjcorr_default_transform(op, rv):
    """Return the default transform registered for the LKJ correlation op.

    The result is a ``MultivariateIntervalTransform`` built with the bounds
    ``floatX(-1.0)`` and ``floatX(1.0)``. Neither ``op`` nor ``rv`` is used
    in the body.
    """
    return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))


class LKJCorr:
    r"""
    The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.

    The LKJ distribution is a prior distribution for correlation matrices.
    If eta = 1 this corresponds to the uniform distribution over correlation
    matrices. For eta -> oo the LKJ prior approaches the identity matrix.

    ========  ==============================================
    Support   Upper triangular matrix with values in [-1, 1]
    ========  ==============================================

    Parameters
    ----------
    n : tensor_like of int
        Dimension of the covariance matrix (n > 1).
    eta : tensor_like of float
        The shape parameter (eta > 0) of the LKJ distribution. eta = 1
        implies a uniform distribution of the correlation matrices;
        larger values put more weight on matrices with few correlations.
    return_matrix : bool, default=False
        If True, returns the full correlation matrix.
        If False, only the values of the upper triangular matrix (excluding
        the diagonal) are returned, as a single vector of length n(n-1)/2,
        for memory efficiency.

    Notes
    -----
    This is mainly useful if you want the standard deviations to be fixed, as
    LKJCholeskyCov is optimized for the case where they come from a
    distribution.

    Examples
    --------
    .. code:: python

        with pm.Model() as model:

            # Define the vector of fixed standard deviations
            sds = 3*np.ones(10)

            corr = pm.LKJCorr(
                'corr', eta=4, n=10, return_matrix=True
            )

            # Define a new MvNormal with the given correlation matrix
            vals = sds*pm.MvNormal('vals', mu=np.zeros(10), cov=corr, shape=10)

            # Or transform an uncorrelated normal distribution:
            vals_raw = pm.Normal('vals_raw', shape=10)
            chol = pt.linalg.cholesky(corr)
            vals = sds*pt.dot(chol,vals_raw)

            # The matrix is internally still sampled as an upper triangular vector
            # If you want access to it in matrix form in the trace, add
            pm.Deterministic('corr_mat', corr)

    References
    ----------
    .. [LKJ2009] Lewandowski, D., Kurowicka, D. and Joe, H. (2009).
        "Generating random correlation matrices based on vines and
        extended onion method." Journal of multivariate analysis,
        100(9), pp.1989-2001.
    """

    def __new__(cls, name, n, eta, *, return_matrix=False, **kwargs):
        """Create the named ``_LKJCorr`` variable, unpacked to a matrix on request."""
        packed = _LKJCorr(name, eta=eta, n=n, **kwargs)
        return cls.vec_to_corr_mat(packed, n) if return_matrix else packed

    @classmethod
    def dist(cls, n, eta, *, return_matrix=False, **kwargs):
        """Build ``_LKJCorr.dist``, unpacked to a matrix on request."""
        packed = _LKJCorr.dist(eta=eta, n=n, **kwargs)
        if return_matrix:
            return cls.vec_to_corr_mat(packed, n)
        return packed

    @classmethod
    def vec_to_corr_mat(cls, vec, n):
        """Expand a packed strict-upper-triangle vector into an ``(n, n)`` matrix.

        The entries of ``vec`` (last axis) are written into the strict upper
        triangle of a zero matrix, mirrored across the diagonal, and a unit
        diagonal is added. Leading axes of ``vec`` are kept.
        """
        batch_shape = vec.shape[:-1]
        upper = pt.zeros(pt.concatenate([batch_shape, (n, n)]))
        # Row/column positions of the strict upper triangle (diagonal offset 1).
        rows, cols = np.triu_indices(n, 1)
        upper = pt.subtensor.set_subtensor(upper[..., rows, cols], vec)
        mirrored = pt.moveaxis(upper, -2, -1)
        return upper + mirrored + pt.diag(pt.ones(n))


class MatrixNormalRV(RandomVariable):
name = "matrixnormal"
ndim_supp = 2
Expand Down
27 changes: 22 additions & 5 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pymc.distributions.multivariate import (
MultivariateIntervalTransform,
_LKJCholeskyCov,
_LKJCorr,
_OrderedMultinomial,
posdef,
quaddist_matrix,
Expand Down Expand Up @@ -558,7 +559,7 @@ def test_wishart(self, n):
@pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
def test_lkjcorr(self, x, eta, n, lp):
with pm.Model() as model:
pm.LKJCorr("lkj", eta=eta, n=n, transform=None)
pm.LKJCorr("lkj", eta=eta, n=n, transform=None, return_matrix=False)

point = {"lkj": x}
decimals = select_by_precision(float64=6, float32=4)
Expand Down Expand Up @@ -1330,7 +1331,7 @@ def test_kronecker_normal_moment(self, mu, covs, size, expected):
)
def test_lkjcorr_moment(self, n, eta, size, expected):
with pm.Model() as model:
pm.LKJCorr("x", n=n, eta=eta, size=size)
pm.LKJCorr("x", n=n, eta=eta, size=size, return_matrix=False)
assert_moment_is_expected(model, expected)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -1465,6 +1466,22 @@ def test_with_cov_rv(

assert prior["mv"].shape == (10, 4, 3)

def test_with_lkjcorr_matrix(self):
    """Prior draws of ``LKJCorr(..., return_matrix=True)`` are correlation-shaped.

    Checks that the sampled deterministic is square, has ones on the
    diagonal, is symmetric, and stays within [-1, 1].
    """
    with pm.Model():
        corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True)
        pm.Deterministic("corr_mat", corr)
        pm.MvNormal("mv", 0.0, cov=corr, size=4)
        prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)

    corr_draws = prior["corr_mat"]
    diag_idx = [0, 1, 2]

    assert corr_draws.shape == (10, 3, 3)  # square
    assert (corr_draws[:, diag_idx, diag_idx] == 1.0).all()  # 1.0 on diagonal
    assert (corr_draws == corr_draws.transpose(0, 2, 1)).all()  # symmetric
    # constrained between -1 and 1
    assert corr_draws.max() <= 1.0 and corr_draws.min() >= -1.0

def test_issue_3758(self):
np.random.seed(42)
ndim = 50
Expand Down Expand Up @@ -2133,7 +2150,7 @@ class TestOrderedMultinomial(BaseTestDistributionRandom):


class TestLKJCorr(BaseTestDistributionRandom):
pymc_dist = pm.LKJCorr
pymc_dist = _LKJCorr
pymc_dist_params = {"n": 3, "eta": 1.0}
expected_rv_op_params = {"n": 3, "eta": 1.0}

Expand Down Expand Up @@ -2161,7 +2178,7 @@ def ref_rand(size, n, eta):
return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2

continuous_random_tester(
pm.LKJCorr,
_LKJCorr,
{
"n": Domain([2, 10, 50], edges=(None, None)),
"eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
Expand All @@ -2186,7 +2203,7 @@ def ref_rand(size, n, eta):
)
def test_LKJCorr_default_transform(shape):
with pm.Model() as m:
x = pm.LKJCorr("x", n=2, eta=1, shape=shape)
x = pm.LKJCorr("x", n=2, eta=1, shape=shape, return_matrix=False)
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
assert m.logp(sum=False)[0].type.shape == shape[:-1]

Expand Down

0 comments on commit ef1f91f

Please sign in to comment.