diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 7001cb6b4af..4573865189c 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -45,6 +45,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which, - Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169) - Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116)) - Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721)) +- Refactored MvNormal.random method with better handling of sample, batch and event shapes. [#4207](https://github.com/pymc-devs/pymc3/pull/4207) ### Documentation - Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)). diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index 79831784b08..2d4c433d288 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -36,7 +36,7 @@ from .continuous import ChiSquared, Normal from .special import gammaln, multigammaln from .dist_math import bound, logpow, factln -from .shape_utils import to_tuple +from .shape_utils import to_tuple, broadcast_dist_samples_to from ..math import kron_dot, kron_diag, kron_solve_lower, kronecker @@ -250,58 +250,36 @@ def random(self, point=None, size=None): ------- array """ - if size is None: - size = tuple() - else: - if not isinstance(size, tuple): - try: - size = tuple(size) - except TypeError: - size = (size,) + size = to_tuple(size) - if self._cov_type == "cov": - mu, cov = draw_values([self.mu, self.cov], point=point, size=size) - if mu.shape[-1] != cov.shape[-1]: - raise ValueError("Shapes for mu and cov don't match") + param_attribute = getattr(self, "chol_cov" if self._cov_type == "chol" else self._cov_type) + mu, param = draw_values([self.mu, param_attribute], point=point, size=size) - try: - dist = stats.multivariate_normal(mean=mu, cov=cov, allow_singular=True) - except ValueError: - size += (mu.shape[-1],) - return np.nan * np.zeros(size) - return dist.rvs(size) - elif self._cov_type == "chol": - mu, chol = draw_values([self.mu, self.chol_cov], point=point, size=size) - if size and mu.ndim == len(size) and mu.shape == size: - mu = mu[..., np.newaxis] - if mu.shape[-1] != chol.shape[-1] and mu.shape[-1] != 1: - raise ValueError("Shapes for mu and chol don't match") - broadcast_shape = np.broadcast(np.empty(mu.shape[:-1]), np.empty(chol.shape[:-2])).shape - - mu = np.broadcast_to(mu, broadcast_shape + (chol.shape[-1],)) - chol = np.broadcast_to(chol, broadcast_shape + chol.shape[-2:]) - # If mu and chol were fixed by the point, only the standard normal - # should change - if mu.shape[: len(size)] != size: - std_norm_shape = size + mu.shape - else: - std_norm_shape = mu.shape - standard_normal = np.random.standard_normal(std_norm_shape) - return mu + np.einsum("...ij,...j->...i", chol, standard_normal) - else: - mu, tau = draw_values([self.mu, self.tau], point=point, size=size) - if mu.shape[-1] != tau[0].shape[-1]: - raise ValueError("Shapes for mu and tau don't match") + dist_shape = to_tuple(self.shape) + output_shape = size + dist_shape - size += (mu.shape[-1],) - try: - chol = linalg.cholesky(tau, lower=True) - except linalg.LinAlgError: - return np.nan * np.zeros(size) + # Simple, there can be only be 1 batch dimension, only available from `mu`. + # Insert it into `param` before events, if there is a sample shape in front. + if param.ndim > 2 and dist_shape[:-1]: + param = param.reshape(size + (1,) + param.shape[-2:]) + + mu = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)[0] + param = np.broadcast_to(param, shape=output_shape + dist_shape[-1:]) + + assert mu.shape == output_shape + assert param.shape == output_shape + dist_shape[-1:] + + if self._cov_type == "cov": + chol = np.linalg.cholesky(param) + elif self._cov_type == "chol": + chol = param + else: # tau -> chol -> swapaxes (chol, -1, -2) -> inv ... + lower_chol = np.linalg.cholesky(param) + upper_chol = np.swapaxes(lower_chol, -1, -2) + chol = np.linalg.inv(upper_chol) - standard_normal = np.random.standard_normal(size) - transformed = linalg.solve_triangular(chol, standard_normal.T, lower=True) - return mu + transformed.T + standard_normal = np.random.standard_normal(output_shape) + return mu + np.einsum("...ij,...j->...i", chol, standard_normal) def logp(self, value): """ @@ -399,13 +377,13 @@ def random(self, point=None, size=None): nu, mu = draw_values([self.nu, self.mu], point=point, size=size) if self._cov_type == "cov": (cov,) = draw_values([self.cov], point=point, size=size) - dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov) + dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape) elif self._cov_type == "tau": (tau,) = draw_values([self.tau], point=point, size=size) - dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau) + dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape) else: (chol,) = draw_values([self.chol_cov], point=point, size=size) - dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol) + dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape) samples = dist.random(point, size) @@ -1915,6 +1893,7 @@ def random(self, point=None, size=None): """ # Expand params into terms MvNormal can understand to force consistency self._setup_random() + self.mv_params["shape"] = self.shape dist = MvNormal.dist(**self.mv_params) return dist.random(point, size) diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index b15eb9b00e5..4daa1857b85 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import pytest import numpy as np import numpy.testing as npt @@ -27,6 +28,7 @@ draw_values, _DrawValuesContext, _DrawValuesContextBlocker, + to_tuple, ) from .helpers import SeededTest from .test_distributions import ( @@ -1544,3 +1546,112 @@ def test_Triangular( prior_samples=prior_samples, ) assert prior["target"].shape == (prior_samples,) + shape + + +def generate_shapes(include_params=False, xfail=False): + # fmt: off + mudim_as_event = [ + [None, 1, 3, 10, (10, 3), 100], + [(3,)], + [(1,), (3,)], + ["cov", "chol", "tau"] + ] + # fmt: on + mudim_as_dist = [ + [None, 1, 3, 10, (10, 3), 100], + [(10, 3)], + [(1,), (3,), (1, 1), (1, 3), (10, 1), (10, 3)], + ["cov", "chol", "tau"], + ] + if not include_params: + del mudim_as_event[-1] + del mudim_as_dist[-1] + data = itertools.chain(itertools.product(*mudim_as_event), itertools.product(*mudim_as_dist)) + if xfail: + data = list(data) + for index in range(len(data)): + if data[index][0] in (None, 1): + data[index] = pytest.param( + *data[index], marks=pytest.mark.xfail(reason="wait for PR #4214") + ) + return data + + +class TestMvNormal(SeededTest): + @pytest.mark.parametrize( + ["sample_shape", "dist_shape", "mu_shape", "param"], + generate_shapes(include_params=True, xfail=False), + ids=str, + ) + def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param): + dist = pm.MvNormal.dist(mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape) + output_shape = to_tuple(sample_shape) + dist_shape + assert dist.random(size=sample_shape).shape == output_shape + + @pytest.mark.parametrize( + ["sample_shape", "dist_shape", "mu_shape"], + generate_shapes(include_params=False, xfail=True), + ids=str, + ) + def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape) + sd_dist = pm.Exponential.dist(1.0, shape=3) + chol, corr, stds = pm.LKJCholeskyCov( + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True + ) + mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape) + prior = pm.sample_prior_predictive(samples=sample_shape) + + assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape + + @pytest.mark.parametrize( + ["sample_shape", "dist_shape", "mu_shape"], + generate_shapes(include_params=False, xfail=True), + ids=str, + ) + def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape) + sd_dist = pm.Exponential.dist(1.0, shape=3) + chol, corr, stds = pm.LKJCholeskyCov( + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True + ) + mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape) + prior = pm.sample_prior_predictive(samples=sample_shape) + + assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape + + def test_issue_3758(self): + np.random.seed(42) + ndim = 50 + with pm.Model() as model: + a = pm.Normal("a", sigma=100, shape=ndim) + b = pm.Normal("b", mu=a, sigma=1, shape=ndim) + c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim) + d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim) + samples = pm.sample_prior_predictive(1000) + + for var in "abcd": + assert not np.isnan(np.std(samples[var])) + + def test_issue_3829(self): + with pm.Model() as model: + x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5)) + trace_pp = pm.sample_prior_predictive(50) + + assert np.shape(trace_pp["x"][0]) == (2, 5) + + def test_issue_3706(self): + N = 10 + Sigma = np.eye(2) + + with pm.Model() as model: + + X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2)) + betas = pm.Normal("betas", 0, 1, shape=2) + y = pm.Deterministic("y", pm.math.dot(X, betas)) + + prior_pred = pm.sample_prior_predictive(1) + + assert prior_pred["X"].shape == (1, N, 2) diff --git a/pymc3/tests/test_mixture.py b/pymc3/tests/test_mixture.py index 27914b6d74f..9f016ad3332 100644 --- a/pymc3/tests/test_mixture.py +++ b/pymc3/tests/test_mixture.py @@ -368,7 +368,7 @@ def build_toy_dataset(N, K): ) ) chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True)) - comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i])) + comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i], shape=D)) pm.Mixture("x_obs", pi, comp_dist, observed=X) with model: