Allow batched parameters in MvNormal and MvStudentT distributions
ricardoV94 committed Sep 12, 2023
1 parent c3efb11 commit fc7ed02
Showing 3 changed files with 140 additions and 95 deletions.
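In short: `mu` and the matrix-valued parameters (`cov`, `tau`, `chol`) of `MvNormal`, and additionally `nu` for `MvStudentT`, may now carry leading batch dimensions instead of being restricted to a single vector/matrix (and a scalar `nu`). A minimal sketch of the behavior this commit enables (shapes are illustrative, not taken from the diff):

```python
import numpy as np
import pymc as pm

# A batch of 4 MvNormals over 3 dimensions: the last two axes of cov are the
# matrix axes, everything in front of them is a batch axis.
batched_cov = np.stack([np.eye(3) * s for s in (1.0, 2.0, 3.0, 4.0)])
x = pm.MvNormal.dist(mu=np.zeros(3), cov=batched_cov)

print(pm.draw(x).shape)                           # (4, 3)
print(pm.logp(x, np.zeros((4, 3))).eval().shape)  # (4,)
```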
95 changes: 31 additions & 64 deletions pymc/distributions/multivariate.py
@@ -123,20 +123,17 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):

    if cov is not None:
        cov = pt.as_tensor_variable(cov)
-        if cov.ndim != 2:
-            raise ValueError("cov must be two dimensional.")
+        if cov.ndim < 2:
+            raise ValueError("cov must be at least two dimensional.")
    elif tau is not None:
        tau = pt.as_tensor_variable(tau)
-        if tau.ndim != 2:
-            raise ValueError("tau must be two dimensional.")
-        # TODO: What's the correct order/approach (in the non-square case)?
-        # `pytensor.tensor.nlinalg.tensorinv`?
+        if tau.ndim < 2:
+            raise ValueError("tau must be at least two dimensional.")
        cov = matrix_inverse(tau)
    else:
-        # TODO: What's the correct order/approach (in the non-square case)?
        chol = pt.as_tensor_variable(chol)
-        if chol.ndim != 2:
-            raise ValueError("chol must be two dimensional.")
+        if chol.ndim < 2:
+            raise ValueError("chol must be at least two dimensional.")

        # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
        chol.tag.lower_triangular = True
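The dimensionality checks relax from exactly-2D to at-least-2D: the trailing two axes are taken as the matrix axes, and any leading axes are treated as batch axes. A quick sketch of what the relaxed check accepts (illustrative shapes):

```python
import numpy as np
import pytensor.tensor as pt

# A stack of 4 covariance matrices: ndim == 3, accepted after this commit.
batched_cov = pt.as_tensor_variable(np.broadcast_to(np.eye(2), (4, 2, 2)))
print(batched_cov.ndim)  # 3 -> passes the new `< 2` check

vec = pt.as_tensor_variable(np.ones(2))
print(vec.ndim)  # 1 -> still rejected: "cov must be at least two dimensional."
```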
@@ -145,10 +142,10 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
    return cov


-def quaddist_parse(value, mu, cov, mat_type="cov"):
+def quaddist_chol(value, mu, cov):
    """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
-    if value.ndim > 2 or value.ndim == 0:
-        raise ValueError("Invalid dimension for value: %s" % value.ndim)
+    if value.ndim == 0:
+        raise ValueError(f"Invalid dimension for value: {value.ndim}")
    if value.ndim == 1:
        onedim = True
        value = value[None, :]
@@ -157,42 +154,21 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):

    delta = value - mu
    chol_cov = nan_lower_cholesky(cov)
-    if mat_type != "tau":
-        dist, logdet, ok = quaddist_chol(delta, chol_cov)
-    else:
-        dist, logdet, ok = quaddist_tau(delta, chol_cov)
-    if onedim:
-        return dist[0], logdet, ok
-
-    return dist, logdet, ok
-
-
-def quaddist_chol(delta, chol_mat):
-    diag = pt.diag(chol_mat)
+    diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
    # Check if the covariance matrix is positive definite.
-    ok = pt.all(diag > 0)
+    ok = pt.all(diag > 0, axis=-1)
    # If not, replace the diagonal. We return -inf later, but
    # need to prevent solve_lower from throwing an exception.
-    chol_cov = pt.switch(ok, chol_mat, 1)
-
-    delta_trans = solve_lower(chol_cov, delta.T).T
+    chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
+    delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
    quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
-
+    logdet = pt.log(diag).sum(axis=-1)

-def quaddist_tau(delta, chol_mat):
-    diag = pt.nlinalg.diag(chol_mat)
-    # Check if the precision matrix is positive definite.
-    ok = pt.all(diag > 0)
-    # If not, replace the diagonal. We return -inf later, but
-    # need to prevent solve_lower from throwing an exception.
-    chol_tau = pt.switch(ok, chol_mat, 1)
-
-    delta_trans = pt.dot(delta, chol_tau)
-    quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = -pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
+    if onedim:
+        return quaddist[0], logdet, ok
+    else:
+        return quaddist, logdet, ok


class MvNormal(Continuous):
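`quaddist_parse` and the separate `quaddist_chol`/`quaddist_tau` helpers collapse into a single batched `quaddist_chol`: precision matrices are already converted via `matrix_inverse(tau)` in `quaddist_matrix`, so only the Cholesky path remains, and every step (diagonal extraction, positive-definiteness flag, triangular solve, log-determinant) now acts on the trailing matrix axes. A rough NumPy rendering of the same logic (a hypothetical helper, not part of PyMC; a generic solve stands in for PyTensor's triangular `solve_lower`):

```python
import numpy as np

def batched_quaddist(value, mu, chol_cov):
    """Batched (x - mu)^T Sigma^-1 (x - mu) given Sigma = L @ L.T."""
    delta = value - mu                                # (..., k)
    diag = np.diagonal(chol_cov, axis1=-2, axis2=-1)  # (..., k)
    ok = np.all(diag > 0, axis=-1)                    # one flag per batch member
    # Solve L @ y = delta batch-wise over the leading dimensions.
    y = np.linalg.solve(chol_cov, delta[..., None])[..., 0]
    quaddist = (y**2).sum(axis=-1)                    # squared Mahalanobis distance
    logdet = np.log(diag).sum(axis=-1)                # log|L| = 0.5 * log|Sigma|
    return quaddist, logdet, ok
```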
@@ -290,7 +266,7 @@ def logp(value, mu, cov):
        -------
        TensorVariable
        """
-        quaddist, logdet, ok = quaddist_parse(value, mu, cov)
+        quaddist, logdet, ok = quaddist_chol(value, mu, cov)
        k = floatX(value.shape[-1])
        norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))
        return check_parameters(
@@ -307,22 +283,6 @@ class MvStudentTRV(RandomVariable):
    dtype = "floatX"
    _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")

-    def make_node(self, rng, size, dtype, nu, mu, cov):
-        nu = pt.as_tensor_variable(nu)
-        if not nu.ndim == 0:
-            raise ValueError("nu must be a scalar (ndim=0).")
-
-        return super().make_node(rng, size, dtype, nu, mu, cov)
-
-    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
-        dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
-
-        if mu is None:
-            mu = np.array([0.0], dtype=dtype)
-        if cov is None:
-            cov = np.array([[1.0]], dtype=dtype)
-        return super().__call__(nu, mu, cov, size=size, **kwargs)
-
    def _supp_shape_from_params(self, dist_params, param_shapes=None):
        return supp_shape_from_ref_param_shape(
            ndim_supp=self.ndim_supp,
@@ -333,14 +293,21 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):

    @classmethod
    def rng_fn(cls, rng, nu, mu, cov, size):
        if size is None:
-            _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+            # When size is implicit, we need to broadcast parameters correctly,
+            # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
+            # nu broadcasts mu and cov
+            if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
+                _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+            # nu is broadcasted by either mu or cov
+            elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
+                nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)

        mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)

        # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
        chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]

        if size:
            mu = np.broadcast_to(mu, size + (mu.shape[-1],))

        return (mv_samples / chi2_samples) + mu
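`rng_fn` draws from the standard scale-mixture representation of the multivariate t: a zero-mean MvNormal draw divided by `sqrt(chi2(nu) / nu)`, then shifted by `mu`; the new branching only ensures that `nu` and the MvNormal draws end up with matching batch dimensions first. The construction itself, as a standalone NumPy sketch (illustrative values):

```python
import numpy as np

rng = np.random.default_rng(42)
nu, mu, cov = 5.0, np.zeros(3), np.eye(3)

z = rng.multivariate_normal(np.zeros(3), cov, size=1000)   # MvNormal(0, cov) draws
w = np.sqrt(rng.chisquare(nu, size=1000) / nu)[..., None]  # chi-square mixing factor
draws = z / w + mu                                         # ~ MvStudentT(nu, mu, cov)
```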


@@ -390,7 +357,7 @@ class MvStudentT(Continuous):
    rv_op = mv_studentt

    @classmethod
-    def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=True, **kwargs):
+    def dist(cls, nu, *, Sigma=None, mu, scale=None, tau=None, chol=None, lower=True, **kwargs):
        cov = kwargs.pop("cov", None)
        if cov is not None:
            warnings.warn(
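Note the `*` added to the signature: every parameter after `nu` in `MvStudentT.dist` is now keyword-only. An illustrative call under the new signature:

```python
import numpy as np
import pymc as pm

# mu, scale, etc. must be spelled out by keyword.
d = pm.MvStudentT.dist(nu=4, mu=np.zeros(2), scale=np.eye(2))
```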
@@ -432,7 +399,7 @@ def logp(value, nu, mu, scale):
        -------
        TensorVariable
        """
-        quaddist, logdet, ok = quaddist_parse(value, mu, scale)
+        quaddist, logdet, ok = quaddist_chol(value, mu, scale)
        k = floatX(value.shape[-1])

        norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * pt.log(nu * np.pi)
27 changes: 12 additions & 15 deletions tests/distributions/test_mixture.py
@@ -333,21 +333,18 @@ def test_list_multivariate_components_deterministic_weights(self, weights, compo
        assert not repetitions

        # Test logp
-        # MvNormal logp is currently limited to 2d values
-        expectation = pytest.raises(ValueError) if mix_eval.ndim > 2 else does_not_raise()
-        with expectation:
-            mix_logp_eval = logp(mix, mix_eval).eval()
-            assert mix_logp_eval.shape == expected_shape[:-1]
-            bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
-            expected_logp = np.stack(
-                (
-                    logp(components[0], mix_eval).eval(),
-                    logp(components[1], mix_eval).eval(),
-                ),
-                axis=-1,
-            )[bcast_weights == 1]
-            expected_logp = expected_logp.reshape(expected_shape[:-1])
-            assert np.allclose(mix_logp_eval, expected_logp)
+        mix_logp_eval = logp(mix, mix_eval).eval()
+        assert mix_logp_eval.shape == expected_shape[:-1]
+        bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
+        expected_logp = np.stack(
+            (
+                logp(components[0], mix_eval).eval(),
+                logp(components[1], mix_eval).eval(),
+            ),
+            axis=-1,
+        )[bcast_weights == 1]
+        expected_logp = expected_logp.reshape(expected_shape[:-1])
+        assert np.allclose(mix_logp_eval, expected_logp)

    def test_component_choice_random(self):
        """Test that mixture choices change over evaluations"""
113 changes: 97 additions & 16 deletions tests/distributions/test_multivariate.py
@@ -330,6 +330,36 @@ def test_mvnormal_init_fail(self):
        with pytest.raises(ValueError):
            x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), tau=np.eye(3), size=3)

+    @pytest.mark.parametrize("batch_mu", (False, True))
+    @pytest.mark.parametrize("batch_cov", (False, True))
+    @pytest.mark.parametrize("use_tau", (False, True))
+    def test_mvnormal_batched_dims(self, batch_mu, batch_cov, use_tau):
+        def ref_logp_core(value, mu, cov):
+            return st.multivariate_normal.logpdf(value, mu, cov)
+
+        ref_logp = np.vectorize(ref_logp_core, signature="(a),(a),(a,a)->()")
+
+        mu = np.arange(5 * 3 * 2).reshape(5, 3, 2) + 1
+        cov = np.eye(2) * mu[..., None]
+        value = mu - np.mean(mu)
+
+        if not batch_mu:
+            mu = mu[0, 0]
+            assert mu.ndim == 1
+        if not batch_cov:
+            cov = cov[0, 0]
+            assert cov.ndim == 2
+
+        if use_tau:
+            dist = pm.MvNormal.dist(mu=mu, tau=np.linalg.inv(cov))
+        else:
+            dist = pm.MvNormal.dist(mu=mu, cov=cov)
+
+        np.testing.assert_allclose(
+            pm.logp(dist, value).eval(),
+            ref_logp(value, mu, cov),
+        )
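The reference logp in this test is built with `np.vectorize` plus a gufunc signature, which loops SciPy's un-batched `logpdf` over arbitrary leading batch dimensions; the same pattern (with an extra scalar `nu` argument) backs `test_mvt_batched_dims` below. A standalone sketch of the pattern:

```python
import numpy as np
import scipy.stats as st

ref_logp = np.vectorize(st.multivariate_normal.logpdf, signature="(a),(a),(a,a)->()")

value = np.zeros((5, 3, 2))
# The core dims (a) and (a,a) are consumed; the (5, 3) batch shape remains.
print(ref_logp(value, np.zeros(2), np.eye(2)).shape)  # (5, 3)
```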

@pytest.mark.parametrize("n", [1, 2, 3])
def test_matrixnormal(self, n):
mat_scale = 1e3 # To reduce logp magnitude
@@ -472,6 +502,40 @@ def test_mvt(self, n):
            extra_args={"size": 2},
        )

+    @pytest.mark.parametrize("batch_nu", (False, True))
+    @pytest.mark.parametrize("batch_mu", (False, True))
+    @pytest.mark.parametrize("batch_cov", (False, True))
+    @pytest.mark.parametrize("use_tau", (False, True))
+    def test_mvt_batched_dims(self, batch_nu, batch_mu, batch_cov, use_tau):
+        def ref_logp_core(value, nu, mu, cov):
+            return st.multivariate_t.logpdf(value, mu, cov, df=nu)
+
+        ref_logp = np.vectorize(ref_logp_core, signature="(a),(),(a),(a,a)->()")
+
+        nu = np.arange(5 * 3).reshape(5, 3) + 2
+        mu = np.arange(5 * 3 * 2).reshape(5, 3, 2) + 1
+        cov = np.eye(2) * mu[..., None]
+        value = mu - np.mean(mu)
+
+        if not batch_nu:
+            nu = nu[0, 0]
+        if not batch_mu:
+            mu = mu[0, 0]
+            assert mu.ndim == 1
+        if not batch_cov:
+            cov = cov[0, 0]
+            assert cov.ndim == 2
+
+        if use_tau:
+            dist = pm.MvStudentT.dist(nu=nu, mu=mu, tau=np.linalg.inv(cov))
+        else:
+            dist = pm.MvStudentT.dist(nu=nu, mu=mu, cov=cov)
+
+        np.testing.assert_allclose(
+            pm.logp(dist, value).eval(),
+            ref_logp(value, nu, mu, cov),
+        )

@pytest.mark.parametrize("n", [2, 3])
def test_wishart(self, n):
with pytest.warns(UserWarning, match="Wishart distribution can currently not be used"):
Expand Down Expand Up @@ -1038,8 +1102,7 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
        with pm.Model() as model:
            x = pm.MvNormal("x", mu=mu, cov=cov, size=size)

-        # MvNormal logp is only implemented for up to 2D variables
-        assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
+        assert_moment_is_expected(model, expected)

    @pytest.mark.parametrize(
        "shape, n_zerosum_axes, expected",
@@ -1109,8 +1172,7 @@ def test_mvstudentt_moment(self, nu, mu, cov, size, expected):
        with pm.Model() as model:
            x = pm.MvStudentT("x", nu=nu, mu=mu, scale=cov, size=size)

-        # MvStudentT logp is only implemented for up to 2D variables
-        assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
+        assert_moment_is_expected(model, expected)

    @pytest.mark.parametrize(
        "mu, rowchol, colchol, size, expected",
@@ -1670,21 +1732,10 @@ def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
"check_errors",
"check_mu_broadcast_helper",
"check_batched_nu",
]

def check_errors(self):
msg = "nu must be a scalar (ndim=0)."
with pm.Model():
with pytest.raises(ValueError, match=re.escape(msg)):
mvstudentt = pm.MvStudentT(
"mvstudentt",
nu=np.array([1, 2]),
mu=np.ones(2),
scale=np.full((2, 2), np.ones(2)),
)

def check_mu_broadcast_helper(self):
"""Test that mu is broadcasted to the shape of cov"""
x = pm.MvStudentT.dist(nu=4, mu=1, scale=np.eye(3))
@@ -1708,6 +1759,36 @@ def check_mu_broadcast_helper(self):
        # mu = x.owner.inputs[4]
        # assert mu.eval().shape == (10, 2, 3)

+    def check_batched_nu(self):
+        rng = np.random.default_rng(sum(map(ord, "batched_nu")))
+        a = (
+            pm.draw(
+                pm.MvStudentT.dist(nu=2, mu=[1, 2, 3], cov=np.eye(3), size=(5_000,)),
+                random_seed=rng,
+            )
+            .std(-1)
+            .mean()
+        )
+        b = (
+            pm.draw(
+                pm.MvStudentT.dist(nu=30, mu=[1, 2, 3], cov=np.eye(3), size=(5_000,)),
+                random_seed=rng,
+            )
+            .std(-1)
+            .mean()
+        )
+        ab = (
+            pm.draw(
+                pm.MvStudentT.dist(nu=[2, 30], mu=[1, 2, 3], cov=np.eye(3), size=(5_000, 2)),
+                random_seed=rng,
+            )
+            .std(-1)
+            .mean(0)
+        )
+
+        assert not np.isclose(ab[0], ab[1], rtol=0.3), "Test is not informative"
+        np.testing.assert_allclose([a, b], ab, rtol=0.1)
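`check_batched_nu` verifies that a single distribution with `nu=[2, 30]` reproduces, batch by batch, the spread of the corresponding scalar-`nu` distributions (the heavier tails at `nu=2` make the two batches clearly distinguishable). From the user side, the new batching looks like this small sketch (illustrative size):

```python
import numpy as np
import pymc as pm

# One RV carrying two degrees-of-freedom values at once:
d = pm.MvStudentT.dist(nu=[2, 30], mu=[1, 2, 3], cov=np.eye(3), size=(1000, 2))
print(pm.draw(d).shape)  # (1000, 2, 3): 1000 draws of a (2,)-batch of 3-vectors
```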


class TestMvStudentTChol(BaseTestDistributionRandom):
    pymc_dist = pm.MvStudentT
