Skip to content

Commit

Permalink
Built the LKJCholeskyCovRV within the init function of _LKJCholeskyCo…
Browse files Browse the repository at this point in the history
…v class
  • Loading branch information
kc611 committed Oct 2, 2021
1 parent bc6ef58 commit 29a7931
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
41 changes: 23 additions & 18 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
)
from pymc.distributions.dist_math import bound, factln, logpow, multigammaln
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.logprob import _logp
from pymc.distributions.logprob import _logp, logpt_sum
from pymc.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
from pymc.math import kron_diag, kron_dot
from pymc.util import UNSET
Expand Down Expand Up @@ -1126,8 +1126,7 @@ def rng_fn(self, rng, eta, n, sd_dist, size=None):
if orig_size is None:
samples = samples[0]
else:
# FIXME: What is sample_shape, is it supposed to be `dist_shape`
samples = np.reshape(samples, orig_size + sample_shape)
samples = np.reshape(samples, orig_size + dist_shape)

return samples

Expand All @@ -1141,6 +1140,22 @@ def __new__(cls, *args, **kwargs):
transform = kwargs.get("transform", UNSET)
if transform is UNSET:
kwargs["transform"] = cls.default_transform()

sd_dist = kwargs["sd_dist"]
cls.rv_op = _LKJCholeskyCovRV(
"_lkjcholeskycov",
2,
(0, 0, 0),
"floatX",
sd_dist=sd_dist,
inplace=False,
)

@_logp.register(_LKJCholeskyCovRV)
def logp(op, var, rvs_to_values, *dist_params, **kwargs):
value_var = rvs_to_values.get(var, var)
return cls.logp(value_var, *dist_params, **kwargs)

return super().__new__(cls, *args, **kwargs)

@classmethod
Expand All @@ -1159,15 +1174,6 @@ def dist(cls, eta, n, sd_dist, *args, **kwargs):
if "shape" in kwargs:
raise ValueError("Invalid parameter: shape.")

# FIXME: The sd_dist info will be lost during `random_make_inplace` rewrite
cls.rv_op = _LKJCholeskyCovRV(
"_lkjcholeskycov",
2,
(0, 0, 0),
"floatX",
sd_dist=sd_dist,
inplace=False,
)
# if sd_dist.shape.ndim not in [0, 1]:
# raise ValueError("Invalid shape for sd_dist.")

Expand All @@ -1194,13 +1200,13 @@ def logp(value, eta, n, sd_dist):

diag_idxs = at.cumsum(at.arange(1, n + 1)) - 1
cumsum = at.cumsum(value ** 2)
variance = at.zeros(n)
variance = at.zeros([n])
variance = at.inc_subtensor(variance[0], value[0] ** 2)
variance = at.inc_subtensor(variance[1:], cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]])
sd_vals = at.sqrt(variance)

# TODO: Since the sd_dist logp is added independently,
# we can perhaps compose the logp terms in a more clean way
# we can perhaps compose the logp terms in a more clean way
logp_sd = _logp(sd_dist.owner.op, sd_vals, {}, *sd_dist.owner.inputs[3:])
logp_sd = logp_sd.sum()
corr_diag = value[diag_idxs] / sd_vals
Expand Down Expand Up @@ -1385,20 +1391,19 @@ def _infer_shape(self, size, dist_params, param_shapes=None):

@classmethod
def rng_fn(cls, rng, n, eta, size=None):
# TODO: rng is not being used by the stat.sebta.rvs!
size = 1 if size is None else size
size = size if isinstance(size, tuple) else (size,)
# original implementation in R see:
# https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
beta = eta - 1.0 + n / 2.0
r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size) - 1.0
r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size, random_state=rng) - 1.0
P = np.eye(n)[:, :, np.newaxis] * np.ones(size)
P[0, 1] = r12
P[1, 1] = np.sqrt(1.0 - r12 ** 2)
for mp1 in range(2, n):
beta -= 0.5
y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size)
z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size)
y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size, random_state=rng)
z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size, random_state=rng)
z = z / np.sqrt(np.einsum("ij,ij->j", z, z))
P[0:mp1, mp1] = np.sqrt(y) * z
P[mp1, mp1] = np.sqrt(1.0 - y)
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3331,4 +3331,4 @@ def test_lkjcholeskycov():
pt = np.random.random(120)
pt = {"packedL_cholesky-cov-packed__": pt}
logp = model.fastlogp(pt)
assert prior["packedL"].shape == (samples,) + dist_shape
assert 0 # Test not complete

0 comments on commit 29a7931

Please sign in to comment.