From 29a7931d1771a8dd05ce33773fa150a6f0f1e6a7 Mon Sep 17 00:00:00 2001 From: kc611 Date: Sat, 2 Oct 2021 17:48:30 +0530 Subject: [PATCH] Built the LKJCholeskyCovRV within the init function of _LKJCholeskyCov class --- pymc/distributions/multivariate.py | 41 +++++++++++++++++------------- pymc/tests/test_distributions.py | 2 +- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index a8f1d7e6e78..4f805961cff 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -51,7 +51,7 @@ ) from pymc.distributions.dist_math import bound, factln, logpow, multigammaln from pymc.distributions.distribution import Continuous, Discrete -from pymc.distributions.logprob import _logp +from pymc.distributions.logprob import _logp, logpt_sum from pymc.distributions.shape_utils import broadcast_dist_samples_to, to_tuple from pymc.math import kron_diag, kron_dot from pymc.util import UNSET @@ -1126,8 +1126,7 @@ def rng_fn(self, rng, eta, n, sd_dist, size=None): if orig_size is None: samples = samples[0] else: - # FIXME: What is sample_shape, is it supposed to be `dist_shape` - samples = np.reshape(samples, orig_size + sample_shape) + samples = np.reshape(samples, orig_size + dist_shape) return samples @@ -1141,6 +1140,22 @@ def __new__(cls, *args, **kwargs): transform = kwargs.get("transform", UNSET) if transform is UNSET: kwargs["transform"] = cls.default_transform() + + sd_dist = kwargs["sd_dist"] + cls.rv_op = _LKJCholeskyCovRV( + "_lkjcholeskycov", + 2, + (0, 0, 0), + "floatX", + sd_dist=sd_dist, + inplace=False, + ) + + @_logp.register(_LKJCholeskyCovRV) + def logp(op, var, rvs_to_values, *dist_params, **kwargs): + value_var = rvs_to_values.get(var, var) + return cls.logp(value_var, *dist_params, **kwargs) + return super().__new__(cls, *args, **kwargs) @classmethod @@ -1159,15 +1174,6 @@ def dist(cls, eta, n, sd_dist, *args, **kwargs): if "shape" in kwargs: raise ValueError("Invalid parameter: shape.") - # FIXME: The sd_dist info will be lost during `random_make_inplace` rewrite - cls.rv_op = _LKJCholeskyCovRV( - "_lkjcholeskycov", - 2, - (0, 0, 0), - "floatX", - sd_dist=sd_dist, - inplace=False, - ) # if sd_dist.shape.ndim not in [0, 1]: # raise ValueError("Invalid shape for sd_dist.") @@ -1194,13 +1200,13 @@ def logp(value, eta, n, sd_dist): diag_idxs = at.cumsum(at.arange(1, n + 1)) - 1 cumsum = at.cumsum(value ** 2) - variance = at.zeros(n) + variance = at.zeros([n]) variance = at.inc_subtensor(variance[0], value[0] ** 2) variance = at.inc_subtensor(variance[1:], cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]]) sd_vals = at.sqrt(variance) # TODO: Since the sd_dist logp is added independently, - # we can perhaps compose the logp terms in a more clean way + # we can perhaps compose the logp terms in a more clean way logp_sd = _logp(sd_dist.owner.op, sd_vals, {}, *sd_dist.owner.inputs[3:]) logp_sd = logp_sd.sum() corr_diag = value[diag_idxs] / sd_vals @@ -1385,20 +1391,19 @@ def _infer_shape(self, size, dist_params, param_shapes=None): @classmethod def rng_fn(cls, rng, n, eta, size=None): - # TODO: rng is not being used by the stat.sebta.rvs! size = 1 if size is None else size size = size if isinstance(size, tuple) else (size,) # original implementation in R see: # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r beta = eta - 1.0 + n / 2.0 - r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size) - 1.0 + r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size, random_state=rng) - 1.0 P = np.eye(n)[:, :, np.newaxis] * np.ones(size) P[0, 1] = r12 P[1, 1] = np.sqrt(1.0 - r12 ** 2) for mp1 in range(2, n): beta -= 0.5 - y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size) - z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size) + y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size, random_state=rng) + z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size, random_state=rng) z = z / np.sqrt(np.einsum("ij,ij->j", z, z)) P[0:mp1, mp1] = np.sqrt(y) * z P[mp1, mp1] = np.sqrt(1.0 - y) diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index b0555f9b67a..b119ca1b800 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -3331,4 +3331,4 @@ def test_lkjcholeskycov(): pt = np.random.random(120) pt = {"packedL_cholesky-cov-packed__": pt} logp = model.fastlogp(pt) - assert prior["packedL"].shape == (samples,) + dist_shape + assert 0 # Test not complete