Skip to content

Commit

Permalink
Refactor LKJCorr distribution to V4
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 24, 2022
1 parent 475ffbf commit d382369
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 91 deletions.
198 changes: 109 additions & 89 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import scipy

from aesara.assert_op import Assert
from aesara.graph.basic import Apply
from aesara.graph.basic import Apply, Constant
from aesara.graph.op import Op
from aesara.sparse.basic import sp_sum
from aesara.tensor import gammaln, sigmoid
Expand All @@ -43,7 +43,12 @@

from pymc.aesaraf import floatX, intX
from pymc.distributions import transforms
from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc.distributions.continuous import (
BoundedContinuous,
ChiSquared,
Normal,
assert_negative_support,
)
from pymc.distributions.dist_math import (
betaln,
check_parameters,
Expand All @@ -57,7 +62,9 @@
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import interval
from pymc.math import kron_diag, kron_dot
from pymc.util import UNSET

__all__ = [
"MvNormal",
Expand Down Expand Up @@ -1415,7 +1422,71 @@ def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=Tru
return chol, corr, stds


class LKJCorr(Continuous):
class LKJCorrRV(RandomVariable):
name = "lkjcorr"
ndim_supp = 1
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")

def make_node(self, rng, size, dtype, n, eta):
n = at.as_tensor_variable(n)
if not n.ndim == 0:
raise ValueError("n must be a scalar (ndim=0).")

eta = at.as_tensor_variable(eta)
if not eta.ndim == 0:
raise ValueError("eta must be a scalar (ndim=0).")

return super().make_node(rng, size, dtype, n, eta)

def _shape_from_params(self, dist_params, **kwargs):
n = dist_params[0]
dist_shape = ((n * (n - 1)) // 2,)
return dist_shape

@classmethod
def rng_fn(cls, rng, n, eta, size):

# We flatten the size to make operations easier, and then rebuild it
if size is None:
orig_size = None
size = 1
else:
orig_size = size
size = np.prod(size)

# original implementation in R see:
# https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
beta = eta - 1.0 + n / 2.0
r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size, random_state=rng) - 1.0
P = np.full((size, n, n), np.eye(n))
P[..., 0, 1] = r12
P[..., 1, 1] = np.sqrt(1.0 - r12 ** 2)
for mp1 in range(2, n):
beta -= 0.5
y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size, random_state=rng)
z = stats.norm.rvs(loc=0, scale=1, size=(size, mp1), random_state=rng)
z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z
P[..., mp1, mp1] = np.sqrt(1.0 - y)
C = np.einsum("...ji,...jk->...ik", P, P)

triu_idx = np.triu_indices(n, k=1)
samples = C[..., triu_idx[0], triu_idx[1]]

if orig_size is None:
samples = samples[0]
else:
dist_shape = (n * (n - 1)) // 2
samples = np.reshape(samples, (*orig_size, dist_shape))
return samples


lkjcorr = LKJCorrRV()


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
Expand Down Expand Up @@ -1457,112 +1528,61 @@ class LKJCorr(Continuous):
100(9), pp.1989-2001.
"""

def __init__(self, eta=None, n=None, p=None, transform="interval", *args, **kwargs):
if (p is not None) and (n is not None) and (eta is None):
warnings.warn(
"Parameters to LKJCorr have changed: shape parameter n -> eta "
"dimension parameter p -> n. Please update your code. "
"Automatically re-assigning parameters for backwards compatibility.",
FutureWarning,
)
self.n = p
self.eta = n
eta = self.eta
n = self.n
elif (n is not None) and (eta is not None) and (p is None):
self.n = n
self.eta = eta
else:
raise ValueError(
"Invalid parameter: please use eta as the shape parameter and "
"n as the dimension parameter."
)
rv_op = lkjcorr

shape = n * (n - 1) // 2
self.mean = floatX(np.zeros(shape))
def __new__(cls, *args, **kwargs):
transform = kwargs.get("transform", UNSET)
if transform is UNSET:
kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0)))
return super().__new__(cls, *args, **kwargs)

if transform == "interval":
transform = transforms.interval(-1, 1)

super().__init__(shape=shape, transform=transform, *args, **kwargs)
warnings.warn(
"Parameters in LKJCorr have been rename: shape parameter n -> eta "
"dimension parameter p -> n. Please double check your initialization.",
FutureWarning,
)
self.tri_index = np.zeros([n, n], dtype="int32")
self.tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
self.tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)

def _random(self, n, eta, size=None):
size = size if isinstance(size, tuple) else (size,)
# original implementation in R see:
# https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
beta = eta - 1.0 + n / 2.0
r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size) - 1.0
P = np.eye(n)[:, :, np.newaxis] * np.ones(size)
P[0, 1] = r12
P[1, 1] = np.sqrt(1.0 - r12 ** 2)
for mp1 in range(2, n):
beta -= 0.5
y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size)
z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size)
z = z / np.sqrt(np.einsum("ij,ij->j", z, z))
P[0:mp1, mp1] = np.sqrt(y) * z
P[mp1, mp1] = np.sqrt(1.0 - y)
C = np.einsum("ji...,jk...->...ik", P, P)
triu_idx = np.triu_indices(n, k=1)
return C[..., triu_idx[0], triu_idx[1]]

def random(self, point=None, size=None):
"""
Draw random values from LKJ distribution.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
Returns
-------
array
"""
# n, eta = draw_values([self.n, self.eta], point=point, size=size)
# size = 1 if size is None else size
# samples = generate_samples(self._random, n, eta, broadcast_shape=(size,))
# return samples
@classmethod
def dist(cls, n, eta, **kwargs):
n = at.as_tensor_variable(intX(n))
eta = at.as_tensor_variable(floatX(eta))
return super().dist([n, eta], **kwargs)

def logp(self, x):
def logp(value, n, eta):
"""
Calculate log-probability of LKJ distribution at specified
value.
Parameters
----------
x: numeric
value: numeric
Value for which log-probability is calculated.
Returns
-------
TensorVariable
"""
n = self.n
eta = self.eta

X = x[self.tri_index]
X = at.fill_diagonal(X, 1)
# TODO: Aesara does not have a `triu_indices`, so we can only work with constant
# n (or else find a different expression)
if not isinstance(n, Constant):
raise NotImplementedError("logp only implemented for constant `n`")

n = n.data
shape = n * (n - 1) // 2
tri_index = np.zeros((n, n), dtype="int32")
tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)

value = take(value, tri_index)
value = at.fill_diagonal(value, 1)

# TODO: _lkj_normalizing_constant currently requires `eta` to be a constant
# this could easily be changed to allow `eta` to be a learnable parameter
if not isinstance(eta, Constant):
raise NotImplementedError("logp only implemented for constant `eta`")
eta = eta.data
result = _lkj_normalizing_constant(eta, n)
result += (eta - 1.0) * at.log(det(X))
result += (eta - 1.0) * at.log(det(value))
return check_parameters(
result,
X >= -1,
X <= 1,
matrix_pos_def(X),
value >= -1,
value <= 1,
matrix_pos_def(value),
eta > 0,
)

Expand Down
3 changes: 1 addition & 2 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,8 +2094,7 @@ def test_wishart(self, n):
)

@pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_lkj(self, x, eta, n, lp):
def test_lkjcorr(self, x, eta, n, lp):
with Model() as model:
LKJCorr("lkj", eta=eta, n=n, transform=None)

Expand Down
39 changes: 39 additions & 0 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,45 @@ def kronecker_rng_fn(self, size, mu, covs=None, sigma=None, rng=None):
]


class TestLKJCorr(BaseTestDistributionRandom):
pymc_dist = pm.LKJCorr
pymc_dist_params = {"n": 3, "eta": 1.0}
expected_rv_op_params = {"n": 3, "eta": 1.0}

sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
sizes_expected = [
(3,),
(3,),
(1, 3),
(1, 3),
(5, 3),
(4, 5, 3),
(2, 4, 2, 3),
]

tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_draws_match_expected",
]

def check_draws_match_expected(self):
def ref_rand(size, n, eta):
shape = int(n * (n - 1) // 2)
beta = eta - 1 + n / 2
return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2

pymc_random(
pm.LKJCorr,
{
"n": Domain([2, 10, 50], edges=(None, None)),
"eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
},
ref_rand=ref_rand,
size=1000,
)


class TestScalarParameterSamples(SeededTest):
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_lkj(self):
Expand Down

0 comments on commit d382369

Please sign in to comment.