diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index c36108aeacc..125959e3d3d 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -25,7 +25,7 @@
 import scipy
 
 from aesara.assert_op import Assert
-from aesara.graph.basic import Apply
+from aesara.graph.basic import Apply, Constant
 from aesara.graph.op import Op
 from aesara.sparse.basic import sp_sum
 from aesara.tensor import gammaln, sigmoid
@@ -43,7 +43,12 @@
 
 from pymc.aesaraf import floatX, intX
 from pymc.distributions import transforms
-from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
+from pymc.distributions.continuous import (
+    BoundedContinuous,
+    ChiSquared,
+    Normal,
+    assert_negative_support,
+)
 from pymc.distributions.dist_math import (
     betaln,
     check_parameters,
@@ -57,7 +62,9 @@
     rv_size_is_none,
     to_tuple,
 )
+from pymc.distributions.transforms import interval
 from pymc.math import kron_diag, kron_dot
+from pymc.util import UNSET
 
 __all__ = [
     "MvNormal",
@@ -1419,7 +1426,71 @@ def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=Tru
     return chol, corr, stds
 
 
-class LKJCorr(Continuous):
+class LKJCorrRV(RandomVariable):
+    name = "lkjcorr"
+    ndim_supp = 1
+    ndims_params = [0, 0]
+    dtype = "floatX"
+    _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
+
+    def make_node(self, rng, size, dtype, n, eta):
+        n = at.as_tensor_variable(n)
+        if not n.ndim == 0:
+            raise ValueError("n must be a scalar (ndim=0).")
+
+        eta = at.as_tensor_variable(eta)
+        if not eta.ndim == 0:
+            raise ValueError("eta must be a scalar (ndim=0).")
+
+        return super().make_node(rng, size, dtype, n, eta)
+
+    def _shape_from_params(self, dist_params, **kwargs):
+        n = dist_params[0]
+        dist_shape = ((n * (n - 1)) // 2,)
+        return dist_shape
+
+    @classmethod
+    def rng_fn(cls, rng, n, eta, size):
+
+        # We flatten the size to make operations easier, and then rebuild it
+        if size is None:
+            orig_size = None
+            size = 1
+        else:
+            orig_size = size
+            size = np.prod(size)
+
+        # original implementation in R see:
+        # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+        beta = eta - 1.0 + n / 2.0
+        r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size, random_state=rng) - 1.0
+        P = np.full((size, n, n), np.eye(n))
+        P[..., 0, 1] = r12
+        P[..., 1, 1] = np.sqrt(1.0 - r12 ** 2)
+        for mp1 in range(2, n):
+            beta -= 0.5
+            y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size, random_state=rng)
+            z = stats.norm.rvs(loc=0, scale=1, size=(size, mp1), random_state=rng)
+            z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
+            P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z
+            P[..., mp1, mp1] = np.sqrt(1.0 - y)
+        C = np.einsum("...ji,...jk->...ik", P, P)
+
+        triu_idx = np.triu_indices(n, k=1)
+        samples = C[..., triu_idx[0], triu_idx[1]]
+
+        if orig_size is None:
+            samples = samples[0]
+        else:
+            dist_shape = (n * (n - 1)) // 2
+            samples = np.reshape(samples, (*orig_size, dist_shape))
+        return samples
+
+
+lkjcorr = LKJCorrRV()
+
+
+class LKJCorr(BoundedContinuous):
     r"""
     The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
 
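# A standalone NumPy sketch of the onion-method construction that `rng_fn`
# above vectorizes (cf. Lewandowski, Kurowicka & Joe, 2009, and the
# rethinking R code referenced in the diff), for a single draw. `lkj_onion`
# is a hypothetical helper, not part of this diff.
import numpy as np
from scipy import stats


def lkj_onion(n, eta, rng):
    beta = eta - 1.0 + n / 2.0
    P = np.eye(n)
    r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, random_state=rng) - 1.0
    P[0, 1] = r12
    P[1, 1] = np.sqrt(1.0 - r12 ** 2)
    for mp1 in range(2, n):
        beta -= 0.5
        # y controls the squared norm of the new off-diagonal block,
        # z its direction on the unit sphere
        y = stats.beta.rvs(a=mp1 / 2.0, b=beta, random_state=rng)
        z = stats.norm.rvs(size=mp1, random_state=rng)
        z = z / np.sqrt(z @ z)
        P[0:mp1, mp1] = np.sqrt(y) * z
        P[mp1, mp1] = np.sqrt(1.0 - y)
    # columns of P have unit norm, so P.T @ P has a unit diagonal
    return P.T @ P


C = lkj_onion(4, eta=2.0, rng=np.random.default_rng(123))
assert np.allclose(np.diag(C), 1.0)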
" - "Automatically re-assigning parameters for backwards compatibility.", - FutureWarning, - ) - self.n = p - self.eta = n - eta = self.eta - n = self.n - elif (n is not None) and (eta is not None) and (p is None): - self.n = n - self.eta = eta - else: - raise ValueError( - "Invalid parameter: please use eta as the shape parameter and " - "n as the dimension parameter." - ) + rv_op = lkjcorr - shape = n * (n - 1) // 2 - self.mean = floatX(np.zeros(shape)) + def __new__(cls, *args, **kwargs): + transform = kwargs.get("transform", UNSET) + if transform is UNSET: + kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0))) + return super().__new__(cls, *args, **kwargs) - if transform == "interval": - transform = transforms.interval(-1, 1) - - super().__init__(shape=shape, transform=transform, *args, **kwargs) - warnings.warn( - "Parameters in LKJCorr have been rename: shape parameter n -> eta " - "dimension parameter p -> n. Please double check your initialization.", - FutureWarning, - ) - self.tri_index = np.zeros([n, n], dtype="int32") - self.tri_index[np.triu_indices(n, k=1)] = np.arange(shape) - self.tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape) - - def _random(self, n, eta, size=None): - size = size if isinstance(size, tuple) else (size,) - # original implementation in R see: - # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r - beta = eta - 1.0 + n / 2.0 - r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size) - 1.0 - P = np.eye(n)[:, :, np.newaxis] * np.ones(size) - P[0, 1] = r12 - P[1, 1] = np.sqrt(1.0 - r12 ** 2) - for mp1 in range(2, n): - beta -= 0.5 - y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size) - z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size) - z = z / np.sqrt(np.einsum("ij,ij->j", z, z)) - P[0:mp1, mp1] = np.sqrt(y) * z - P[mp1, mp1] = np.sqrt(1.0 - y) - C = np.einsum("ji...,jk...->...ik", P, P) - triu_idx = np.triu_indices(n, k=1) - return C[..., triu_idx[0], triu_idx[1]] - - def random(self, point=None, size=None): - """ - Draw random values from LKJ distribution. - - Parameters - ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). - - Returns - ------- - array - """ - # n, eta = draw_values([self.n, self.eta], point=point, size=size) - # size = 1 if size is None else size - # samples = generate_samples(self._random, n, eta, broadcast_shape=(size,)) - # return samples + @classmethod + def dist(cls, n, eta, **kwargs): + n = at.as_tensor_variable(intX(n)) + eta = at.as_tensor_variable(floatX(eta)) + return super().dist([n, eta], **kwargs) - def logp(self, x): + def logp(value, n, eta): """ Calculate log-probability of LKJ distribution at specified value. Parameters ---------- - x: numeric + value: numeric Value for which log-probability is calculated. 
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index 9b377825eda..91488a0fed8 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -2094,8 +2094,7 @@ def test_wishart(self, n):
     )
 
     @pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
-    def test_lkj(self, x, eta, n, lp):
+    def test_lkjcorr(self, x, eta, n, lp):
         with Model() as model:
             LKJCorr("lkj", eta=eta, n=n, transform=None)
 
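# Why a rescaled Beta is a valid reference density for the sampler tests
# below: under LKJ(eta), each off-diagonal entry r of an n x n correlation
# matrix marginally satisfies (r + 1) / 2 ~ Beta(a, a) with a = eta - 1 + n / 2,
# which is exactly what `ref_rand` in the new TestLKJCorr class encodes.
# A quick self-contained check of the implied moments:
import numpy as np
from scipy import stats as st

n, eta = 5, 2.0
a = eta - 1 + n / 2
r = (st.beta.rvs(a=a, b=a, size=200_000, random_state=42) - 0.5) * 2
assert abs(r.mean()) < 0.01  # symmetric around 0
assert abs(r.var() - 1 / (2 * a + 1)) < 0.01  # Var[r] = 1 / (2a + 1)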
v4") def test_normalmixture(self): def ref_rand(size, w, mu, sigma):