From ef1f91f2811a2aab897cb8bfd616e9d81af8b61d Mon Sep 17 00:00:00 2001
From: Margus Niitsoo
Date: Tue, 23 Jan 2024 16:10:43 +0200
Subject: [PATCH] Add a flag to LKJCorr to return the unpacked correlation
 matrix (#7100)

* LKJCorr now has an option of returning the n by n matrix instead of just
  the upper triangular part as a vector
---
 pymc/distributions/multivariate.py       | 133 +++++++++++++++--------
 tests/distributions/test_multivariate.py |  27 ++++-
 2 files changed, 112 insertions(+), 48 deletions(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 03ea4f3a5da..2f81ef12739 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -1511,6 +1511,7 @@ def rng_fn(cls, rng, n, eta, size):
     def _random_corr_matrix(cls, rng, n, eta, flat_size):
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+
         beta = eta - 1.0 + n / 2.0
         r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=flat_size, random_state=rng) - 1.0
         P = np.full((flat_size, n, n), np.eye(n))
@@ -1537,48 +1538,8 @@ def log_jac_det(self, *args):
         return super().log_jac_det(*args).sum(-1)
 
-class LKJCorr(BoundedContinuous):
-    r"""
-    The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
-
-    The LKJ distribution is a prior distribution for correlation matrices.
-    If eta = 1 this corresponds to the uniform distribution over correlation
-    matrices. For eta -> oo the LKJ prior approaches the identity matrix.
-
-    ======== ==============================================
-    Support  Upper triangular matrix with values in [-1, 1]
-    ======== ==============================================
-
-    Parameters
-    ----------
-    n : tensor_like of int
-        Dimension of the covariance matrix (n > 1).
-    eta : tensor_like of float
-        The shape parameter (eta > 0) of the LKJ distribution. eta = 1
-        implies a uniform distribution of the correlation matrices;
-        larger values put more weight on matrices with few correlations.
-
-    Notes
-    -----
-    This implementation only returns the values of the upper triangular
-    matrix excluding the diagonal. Here is a schematic for n = 5, showing
-    the indexes of the elements::
-
-        [[- 0 1 2 3]
-         [- - 4 5 6]
-         [- - - 7 8]
-         [- - - - 9]
-         [- - - - -]]
-
-
-    References
-    ----------
-    .. [LKJ2009] Lewandowski, D., Kurowicka, D. and Joe, H. (2009).
-        "Generating random correlation matrices based on vines and
-        extended onion method." Journal of multivariate analysis,
-        100(9), pp.1989-2001.
-    """
-
+# Returns the upper triangular correlation values as a flat vector
+class _LKJCorr(BoundedContinuous):
     rv_op = lkjcorr
 
     @classmethod
@@ -1637,11 +1598,97 @@ def logp(value, n, eta):
         )
 
-@_default_transform.register(LKJCorr)
+@_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
     return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))
 
+class LKJCorr:
+    r"""
+    The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
+
+    The LKJ distribution is a prior distribution for correlation matrices.
+    If eta = 1 this corresponds to the uniform distribution over correlation
+    matrices. For eta -> oo the LKJ prior approaches the identity matrix.
+
+    ======== ==============================================
+    Support  Upper triangular matrix with values in [-1, 1]
+    ======== ==============================================
+
+    Parameters
+    ----------
+    n : tensor_like of int
+        Dimension of the covariance matrix (n > 1).
+    eta : tensor_like of float
+        The shape parameter (eta > 0) of the LKJ distribution. eta = 1
+        implies a uniform distribution over correlation matrices;
+        larger values put more weight on matrices with few correlations.
+    return_matrix : bool, default=False
+        If True, returns the full correlation matrix.
+        If False, returns only the values of the upper triangular matrix,
+        excluding the diagonal, as a single vector of length n(n-1)/2
+        (for memory efficiency).
+
+    Notes
+    -----
+    This is mainly useful if you want the standard deviations to be fixed, as
+    LKJCholeskyCov is optimized for the case where they come from a distribution.
+
+    Examples
+    --------
+    .. code:: python
+
+        with pm.Model() as model:
+
+            # Define the vector of fixed standard deviations
+            sds = 3 * np.ones(10)
+
+            corr = pm.LKJCorr(
+                'corr', eta=4, n=10, return_matrix=True
+            )
+
+            # Define a new MvNormal with the given correlation matrix
+            vals = sds * pm.MvNormal('vals', mu=np.zeros(10), cov=corr, shape=10)
+
+            # Or transform an uncorrelated normal distribution:
+            vals_raw = pm.Normal('vals_raw', shape=10)
+            chol = pt.linalg.cholesky(corr)
+            vals = sds * pt.dot(chol, vals_raw)
+
+            # The matrix is internally still sampled as an upper triangular
+            # vector. If you want access to it in matrix form in the trace, add:
+            pm.Deterministic('corr_mat', corr)
+
+
+    References
+    ----------
+    .. [LKJ2009] Lewandowski, D., Kurowicka, D. and Joe, H. (2009).
+        "Generating random correlation matrices based on vines and
+        extended onion method." Journal of multivariate analysis,
+        100(9), pp.1989-2001.
+    """
+
+    def __new__(cls, name, n, eta, *, return_matrix=False, **kwargs):
+        c_vec = _LKJCorr(name, eta=eta, n=n, **kwargs)
+        if not return_matrix:
+            return c_vec
+        else:
+            return cls.vec_to_corr_mat(c_vec, n)
+
+    @classmethod
+    def dist(cls, n, eta, *, return_matrix=False, **kwargs):
+        c_vec = _LKJCorr.dist(eta=eta, n=n, **kwargs)
+        if not return_matrix:
+            return c_vec
+        else:
+            return cls.vec_to_corr_mat(c_vec, n)
+
+    @classmethod
+    def vec_to_corr_mat(cls, vec, n):
+        # Scatter the packed values into the strict upper triangle, then
+        # mirror them below the diagonal and add the unit diagonal.
+        tri = pt.zeros(pt.concatenate([vec.shape[:-1], (n, n)]))
+        tri = pt.subtensor.set_subtensor(tri[(...,) + np.triu_indices(n, 1)], vec)
+        return tri + pt.moveaxis(tri, -2, -1) + pt.diag(pt.ones(n))
+
+
 class MatrixNormalRV(RandomVariable):
     name = "matrixnormal"
     ndim_supp = 2
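For reference, the packing convention that `vec_to_corr_mat` implements can be illustrated with a small standalone numpy sketch (the helper name `unpack_upper_tri` is illustrative and not part of this patch; it mirrors the non-batched case of the pytensor code above):

.. code:: python

    import numpy as np

    def unpack_upper_tri(vec, n):
        # Scatter the n*(n-1)/2 packed values into the strict upper triangle
        tri = np.zeros((n, n))
        tri[np.triu_indices(n, k=1)] = vec
        # Mirror below the diagonal and add the unit diagonal, matching
        # tri + pt.moveaxis(tri, -2, -1) + pt.diag(pt.ones(n)) above
        return tri + tri.T + np.eye(n)

    vec = np.array([0.1, 0.2, 0.3])  # n = 3 -> 3 packed correlations
    mat = unpack_upper_tri(vec, 3)
    assert np.allclose(mat, mat.T) and np.allclose(np.diag(mat), 1.0)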
diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py
index b85fb65cda5..7cdba547283 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -34,6 +34,7 @@
 from pymc.distributions.multivariate import (
     MultivariateIntervalTransform,
     _LKJCholeskyCov,
+    _LKJCorr,
     _OrderedMultinomial,
     posdef,
     quaddist_matrix,
@@ -558,7 +559,7 @@ def test_wishart(self, n):
     @pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
     def test_lkjcorr(self, x, eta, n, lp):
         with pm.Model() as model:
-            pm.LKJCorr("lkj", eta=eta, n=n, transform=None)
+            pm.LKJCorr("lkj", eta=eta, n=n, transform=None, return_matrix=False)
 
         point = {"lkj": x}
         decimals = select_by_precision(float64=6, float32=4)
@@ -1330,7 +1331,7 @@ def test_kronecker_normal_moment(self, mu, covs, size, expected):
     )
     def test_lkjcorr_moment(self, n, eta, size, expected):
         with pm.Model() as model:
-            pm.LKJCorr("x", n=n, eta=eta, size=size)
+            pm.LKJCorr("x", n=n, eta=eta, size=size, return_matrix=False)
         assert_moment_is_expected(model, expected)
 
     @pytest.mark.parametrize(
@@ -1465,6 +1466,22 @@ def test_with_cov_rv(
 
         assert prior["mv"].shape == (10, 4, 3)
 
+    def test_with_lkjcorr_matrix(self):
+        with pm.Model() as model:
+            corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True)
+            pm.Deterministic("corr_mat", corr)
+            mv = pm.MvNormal("mv", 0.0, cov=corr, size=4)
+            prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)
+
+        assert prior["corr_mat"].shape == (10, 3, 3)  # square
+        assert (prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]] == 1.0).all()  # 1.0 on diagonal
+        assert (prior["corr_mat"] == prior["corr_mat"].transpose(0, 2, 1)).all()  # symmetric
+        assert (
+            prior["corr_mat"].max() <= 1.0 and prior["corr_mat"].min() >= -1.0
+        )  # constrained between -1 and 1
+
     def test_issue_3758(self):
         np.random.seed(42)
         ndim = 50
@@ -2133,7 +2150,7 @@ class TestOrderedMultinomial(BaseTestDistributionRandom):
 
 class TestLKJCorr(BaseTestDistributionRandom):
-    pymc_dist = pm.LKJCorr
+    pymc_dist = _LKJCorr
 
     pymc_dist_params = {"n": 3, "eta": 1.0}
     expected_rv_op_params = {"n": 3, "eta": 1.0}
@@ -2161,7 +2178,7 @@ def ref_rand(size, n, eta):
         return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2
 
     continuous_random_tester(
-        pm.LKJCorr,
+        _LKJCorr,
         {
             "n": Domain([2, 10, 50], edges=(None, None)),
             "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
@@ -2186,7 +2203,7 @@ def ref_rand(size, n, eta):
 )
 def test_LKJCorr_default_transform(shape):
     with pm.Model() as m:
-        x = pm.LKJCorr("x", n=2, eta=1, shape=shape)
+        x = pm.LKJCorr("x", n=2, eta=1, shape=shape, return_matrix=False)
     assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
     assert m.logp(sum=False)[0].type.shape == shape[:-1]
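For completeness, a minimal end-to-end sketch of the new flag, assuming a PyMC build that includes this patch (it follows the docstring example and the assertions of `test_with_lkjcorr_matrix` above; variable names are illustrative):

.. code:: python

    import numpy as np
    import pymc as pm

    with pm.Model():
        # return_matrix=True yields the full 3 x 3 correlation matrix as a
        # symbolic tensor; the underlying free variable is still the packed
        # vector, so a Deterministic is needed to record the matrix itself.
        corr = pm.LKJCorr("corr", n=3, eta=2.0, return_matrix=True)
        pm.Deterministic("corr_mat", corr)
        pm.MvNormal("mv", mu=np.zeros(3), cov=corr, shape=3)
        prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)

    mats = prior["corr_mat"]  # shape (10, 3, 3)
    assert (mats == mats.transpose(0, 2, 1)).all()  # symmetric
    assert np.allclose(mats[:, np.arange(3), np.arange(3)], 1.0)  # unit diagonal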