From 88d34db14c115ece38254eda5467e7d614ce0b54 Mon Sep 17 00:00:00 2001 From: kc611 Date: Wed, 14 Jul 2021 19:08:23 +0530 Subject: [PATCH] Added shape inferring in LKJCholeskyCov --- pymc3/distributions/multivariate.py | 7 ++++++- pymc3/tests/test_distributions.py | 9 ++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index 5c9b56be4b8..6b61aa37f1f 100644 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -1040,6 +1040,11 @@ def _lkj_normalizing_constant(eta, n): class _LKJCholeskyCovRV(RandomVariable): + def _shape_from_params(self, dist_params, **kwargs): + n = dist_params[1] + dist_shape = ((n * (n + 1)) // 2,) + return dist_shape + def __init__(self, *args, sd_dist=None, **kwargs): self.sd_dist = sd_dist self._print_name = _print_name = ("_LKJCholeskyCov", "\\operatorname{_LKJCholeskyCov}") @@ -1152,7 +1157,7 @@ def dist(cls, eta, n, sd_dist, *args, **kwargs): cls.rv_op = _LKJCholeskyCovRV( "_lkjcholeskycov", - 0, + 2, (0, 0, 0), "floatX", sd_dist=sd_dist, diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 43f26e79fb7..674c0adc7ec 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -3140,14 +3140,13 @@ def test_lkjcholeskycov(): dist_shape = ((D * (D + 1)) // 2,) with pm.Model() as model: - sd_dist = pm.HalfCauchy.dist(beta=2.5, size=(D)) + sd_dist = pm.HalfCauchy.dist(beta=2.5) packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist) + with model: + prior = pm.sample() + pt = np.random.random(120) pt = {"packedL_cholesky-cov-packed__": pt} logp = model.fastlogp(pt) - - with model: - prior = pm.sample_prior_predictive(5) - assert prior["packedL"].shape == (samples,) + dist_shape