Skip to content

Commit

Permalink
Refactor LKJCholeskyCov for V4
Browse files Browse the repository at this point in the history
Changes:
* compute_corr now defaults to True
* LKJCholeskyCov now also provides a `.dist` interface
  • Loading branch information
ricardoV94 committed Jan 28, 2022
1 parent eed60c3 commit 1a35a3d
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 172 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ All of the above apply to:
- ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc/pull/4471) and `3.11.2` release notes).
- `pm.sample_posterior_predictive(vars=...)` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc/pull/4343)).
- `ElemwiseCategorical` step method was removed (see [#4701](https://github.com/pymc-devs/pymc/pull/4701))
- `LKJCholeskyCov` `compute_corr` keyword argument is now set to `True` by default (see [#5382](https://github.com/pymc-devs/pymc/pull/5382))

### Ongoing deprecations
- Old API still works in `v4` and has a deprecation warning.
Expand Down
295 changes: 137 additions & 158 deletions pymc/distributions/multivariate.py

Large diffs are not rendered by default.

26 changes: 17 additions & 9 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
RVTransform,
Simplex,
)
from aesara.tensor.subtensor import advanced_set_subtensor1

__all__ = [
"RVTransform",
Expand Down Expand Up @@ -97,22 +96,31 @@ def log_jac_det(self, value, *inputs):


class CholeskyCovPacked(RVTransform):
"""
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
log scale
"""

name = "cholesky-cov-packed"

def __init__(self, param_extract_fn):
self.param_extract_fn = param_extract_fn
def __init__(self, n):
"""
Parameters
----------
n: int
Number of diagonal entries in the LKJCholeskyCov distribution
"""
self.diag_idxs = at.arange(1, n + 1).cumsum() - 1

def backward(self, value, *inputs):
diag_idxs = self.param_extract_fn(inputs)
return advanced_set_subtensor1(value, at.exp(value[diag_idxs]), diag_idxs)
return at.set_subtensor(value[..., self.diag_idxs], at.exp(value[..., self.diag_idxs]))

def forward(self, value, *inputs):
diag_idxs = self.param_extract_fn(inputs)
return advanced_set_subtensor1(value, at.log(value[diag_idxs]), diag_idxs)
return at.set_subtensor(value[..., self.diag_idxs], at.log(value[..., self.diag_idxs]))

def log_jac_det(self, value, *inputs):
diag_idxs = self.param_extract_fn(inputs)
return at.sum(value[diag_idxs])
return at.sum(value[..., self.diag_idxs], axis=-1)


class Chain(RVTransform):
Expand Down
4 changes: 3 additions & 1 deletion pymc/tests/sampler_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ def make_model(cls):
with pm.Model() as model:
sd_mu = np.array([1, 2, 3, 4, 5])
sd_dist = pm.LogNormal.dist(mu=sd_mu, sigma=sd_mu / 10.0, size=5)
chol_packed = pm.LKJCholeskyCov("chol_packed", eta=3, n=5, sd_dist=sd_dist)
chol_packed = pm.LKJCholeskyCov(
"chol_packed", eta=3, n=5, sd_dist=sd_dist, compute_corr=False
)
chol = pm.expand_packed_triangular(5, chol_packed, lower=True)
cov = at.dot(chol, chol.T)
stds = at.sqrt(at.diag(cov))
Expand Down
37 changes: 37 additions & 0 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3352,3 +3352,40 @@ def test_censored_invalid_dist(self):
match="The dist dist was already registered in the current model",
):
x = pm.Censored("x", registered_dist, lower=None, upper=None)


class TestLKJCholeskCov:
    """Checks for the ``pm.LKJCholeskyCov`` entry point (``.dist`` and in-model use)."""

    def test_dist(self):
        """``.dist`` yields a packed draw or a ``(chol, corr, stds)`` triple."""
        # With compute_corr=False a single variable is returned; for n=3 and
        # size=10 its evaluated shape is (10, 6).
        sd_dist = pm.Exponential.dist(1, size=(10, 3))
        x = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist, size=10, compute_corr=False)
        assert x.eval().shape == (10, 6)

        # Without the keyword the result unpacks into three variables.
        sd_dist = pm.Exponential.dist(1, size=3)
        chol, corr, stds = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist)
        assert chol.eval().shape == (3, 3)
        assert corr.eval().shape == (3, 3)
        assert stds.eval().shape == (3,)

    def test_sd_dist_distribution(self):
        """A plain constant tensor passed as ``sd_dist`` raises ``TypeError``."""
        with pm.Model():
            sd_dist = at.constant([1, 2, 3])
            with pytest.raises(TypeError, match="sd_dist must be a Distribution variable"):
                pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)

    def test_sd_dist_registered(self):
        """An ``sd_dist`` already registered in the model raises ``ValueError``."""
        with pm.Model():
            sd_dist = pm.Exponential("sd_dist", 1, size=3)
            with pytest.raises(
                ValueError, match="The dist sd_dist was already registered in the current model"
            ):
                pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)

    def test_no_warning_logp(self):
        """Building the model logp with an LKJCholeskyCov emits no warnings."""
        # Local import: ``pytest.warns(None)`` is deprecated since pytest 7 and
        # rejected by pytest 8, so warnings are escalated to errors instead.
        import warnings

        # Check that calling logp of a model with LKJCholeskyCov does not issue any warnings
        # due to the RandomVariable in the graph
        with pm.Model() as m:
            sd_dist = pm.Exponential.dist(1, size=3)
            pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
            with warnings.catch_warnings():
                # Any warning raised inside this block fails the test.
                warnings.simplefilter("error")
                m.logpt()
51 changes: 50 additions & 1 deletion pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def random_polyagamma(*args, **kwargs):
from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit
from pymc.distributions.dist_math import clipped_beta_rvs
from pymc.distributions.logprob import logp
from pymc.distributions.multivariate import _OrderedMultinomial, quaddist_matrix
from pymc.distributions.multivariate import (
_LKJCholeskyCov,
_OrderedMultinomial,
quaddist_matrix,
)
from pymc.distributions.shape_utils import to_tuple
from pymc.tests.helpers import SeededTest, select_by_precision
from pymc.tests.test_distributions import (
Expand Down Expand Up @@ -1867,6 +1871,43 @@ def ref_rand(size, n, eta):
)


class TestLKJCholeskyCov(BaseTestDistributionRandom):
    """Random-draw checks for ``_LKJCholeskyCov`` run through the base test harness."""

    pymc_dist = _LKJCholeskyCov
    pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
    expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
    size = None

    # Requested ``size`` values, each paired by position with the draw shape
    # expected for n=3.
    sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
    sizes_expected = [
        (6,),
        (6,),
        (1, 6),
        (1, 6),
        (5, 6),
        (4, 5, 6),
        (2, 4, 2, 6),
    ]

    tests_to_run = [
        "check_rv_size",
        "check_draws_match_expected",
    ]

    def check_rv_size(self):
        """Symbolic and evaluated shapes both match the expected shape for every size."""
        size_cases = zip(self.sizes_to_check, self.sizes_expected)
        for requested_size, expected_shape in size_cases:
            # sd_dist carries the requested batch dimensions plus a trailing 3.
            sd_shape = (*to_tuple(requested_size), 3)
            sd_dist = pm.Exponential.dist(1, size=sd_shape)
            rv = self.pymc_dist.dist(n=3, eta=1, sd_dist=sd_dist, size=requested_size)
            symbolic_shape = tuple(rv.shape.eval())
            drawn_shape = rv.eval().shape
            assert drawn_shape == symbolic_shape == expected_shape

    def check_draws_match_expected(self):
        """A seeded n=2 draw with very large eta lies within 0.01 of [0.5, 0, 2.0]."""
        # TODO: Find better comparison:
        # NOTE(review): presumably eta=10_000 pushes correlations close to zero,
        # leaving only the constant standard deviations; confirm.
        rng = aesara.shared(self.get_random_state(reset=True))
        sd_dist = pm.Constant.dist([0.5, 2.0])
        x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=sd_dist, rng=rng)
        reference = np.array([0.5, 0, 2.0])
        deviation = np.abs(x.eval() - reference)
        assert np.all(deviation < 0.01)


class TestScalarParameterSamples(SeededTest):
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_normalmixture(self):
Expand Down Expand Up @@ -2346,9 +2387,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
# pylint: disable=unpacking-non-sequence
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
# pylint: enable=unpacking-non-sequence
mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

Expand All @@ -2363,9 +2406,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
# pylint: disable=unpacking-non-sequence
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
# pylint: enable=unpacking-non-sequence
mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

Expand Down Expand Up @@ -2457,9 +2502,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
# pylint: disable=unpacking-non-sequence
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
# pylint: enable=unpacking-non-sequence
mv = pm.MvGaussianRandomWalk("mv", mu, chol=chol, shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

Expand All @@ -2475,9 +2522,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
# pylint: disable=unpacking-non-sequence
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
# pylint: enable=unpacking-non-sequence
mv = pm.MvGaussianRandomWalk("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

Expand Down
3 changes: 2 additions & 1 deletion pymc/tests/test_idata_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,16 @@ def test_missing_data_model(self):
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)

@pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4")
@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4")
def test_mv_missing_data_model(self):
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)

model = pm.Model()
with model:
mu = pm.Normal("mu", 0, 1, size=2)
sd_dist = pm.HalfNormal.dist(1.0)
# pylint: disable=unpacking-non-sequence
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
# pylint: enable=unpacking-non-sequence
y = pm.MvNormal("y", mu=mu, chol=chol, observed=data)
inference_data = pm.sample(100, chains=2, return_inferencedata=True)

Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def build_toy_dataset(N, K):
mu.append(pm.Normal("mu%i" % i, 0, 10, shape=D))
packed_chol.append(
pm.LKJCholeskyCov(
"chol_cov_%i" % i, eta=2, n=D, sd_dist=pm.HalfNormal.dist(2.5)
"chol_cov_%i" % i, eta=2, n=D, sd_dist=pm.HalfNormal.dist(2.5, size=D)
)
)
chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True))
Expand Down
1 change: 0 additions & 1 deletion pymc/tests/test_posteriors.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ class TestNUTSNormalLong(sf.NutsFixture, sf.NormalFixture):
atol = 0.001


@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4")
class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture):
n_samples = 2000
tune = 1000
Expand Down

0 comments on commit 1a35a3d

Please sign in to comment.