Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MvNormal.random #4207

Merged
merged 13 commits into from
Nov 29, 2020
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116))
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))
- Refactored MvNormal.random method with better handling of sample, batch and event shapes. [#4207](https://github.com/pymc-devs/pymc3/pull/4207)

### Documentation
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).
Expand Down
83 changes: 31 additions & 52 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .continuous import ChiSquared, Normal
from .special import gammaln, multigammaln
from .dist_math import bound, logpow, factln
from .shape_utils import to_tuple
from .shape_utils import to_tuple, broadcast_dist_samples_to
from ..math import kron_dot, kron_diag, kron_solve_lower, kronecker


Expand Down Expand Up @@ -250,58 +250,36 @@ def random(self, point=None, size=None):
-------
array
"""
if size is None:
size = tuple()
else:
if not isinstance(size, tuple):
try:
size = tuple(size)
except TypeError:
size = (size,)
size = to_tuple(size)

if self._cov_type == "cov":
mu, cov = draw_values([self.mu, self.cov], point=point, size=size)
if mu.shape[-1] != cov.shape[-1]:
raise ValueError("Shapes for mu and cov don't match")
param_attribute = getattr(self, "chol_cov" if self._cov_type == "chol" else self._cov_type)
mu, param = draw_values([self.mu, param_attribute], point=point, size=size)

try:
dist = stats.multivariate_normal(mean=mu, cov=cov, allow_singular=True)
except ValueError:
size += (mu.shape[-1],)
return np.nan * np.zeros(size)
return dist.rvs(size)
elif self._cov_type == "chol":
mu, chol = draw_values([self.mu, self.chol_cov], point=point, size=size)
if size and mu.ndim == len(size) and mu.shape == size:
mu = mu[..., np.newaxis]
if mu.shape[-1] != chol.shape[-1] and mu.shape[-1] != 1:
raise ValueError("Shapes for mu and chol don't match")
broadcast_shape = np.broadcast(np.empty(mu.shape[:-1]), np.empty(chol.shape[:-2])).shape

mu = np.broadcast_to(mu, broadcast_shape + (chol.shape[-1],))
chol = np.broadcast_to(chol, broadcast_shape + chol.shape[-2:])
# If mu and chol were fixed by the point, only the standard normal
# should change
if mu.shape[: len(size)] != size:
std_norm_shape = size + mu.shape
else:
std_norm_shape = mu.shape
standard_normal = np.random.standard_normal(std_norm_shape)
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)
else:
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
if mu.shape[-1] != tau[0].shape[-1]:
raise ValueError("Shapes for mu and tau don't match")
dist_shape = to_tuple(self.shape)
output_shape = size + dist_shape

size += (mu.shape[-1],)
try:
chol = linalg.cholesky(tau, lower=True)
except linalg.LinAlgError:
return np.nan * np.zeros(size)
# Simple, there can be only be 1 batch dimension, only available from `mu`.
# Insert it into `param` before events, if there is a sample shape in front.
if param.ndim > 2 and dist_shape[:-1]:
param = param.reshape(size + (1,) + param.shape[-2:])

mu = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)[0]
param = np.broadcast_to(param, shape=output_shape + dist_shape[-1:])

assert mu.shape == output_shape
assert param.shape == output_shape + dist_shape[-1:]

if self._cov_type == "cov":
Sayam753 marked this conversation as resolved.
Show resolved Hide resolved
chol = np.linalg.cholesky(param)
elif self._cov_type == "chol":
chol = param
else: # tau -> chol -> swapaxes (chol, -1, -2) -> inv ...
lower_chol = np.linalg.cholesky(param)
upper_chol = np.swapaxes(lower_chol, -1, -2)
chol = np.linalg.inv(upper_chol)

standard_normal = np.random.standard_normal(size)
transformed = linalg.solve_triangular(chol, standard_normal.T, lower=True)
return mu + transformed.T
standard_normal = np.random.standard_normal(output_shape)
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)

def logp(self, value):
"""
Expand Down Expand Up @@ -399,13 +377,13 @@ def random(self, point=None, size=None):
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
if self._cov_type == "cov":
(cov,) = draw_values([self.cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
elif self._cov_type == "tau":
(tau,) = draw_values([self.tau], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
else:
(chol,) = draw_values([self.chol_cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)

samples = dist.random(point, size)

Expand Down Expand Up @@ -1915,6 +1893,7 @@ def random(self, point=None, size=None):
"""
# Expand params into terms MvNormal can understand to force consistency
self._setup_random()
self.mv_params["shape"] = self.shape
dist = MvNormal.dist(**self.mv_params)
return dist.random(point, size)

Expand Down
111 changes: 111 additions & 0 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import pytest
import numpy as np
import numpy.testing as npt
Expand All @@ -27,6 +28,7 @@
draw_values,
_DrawValuesContext,
_DrawValuesContextBlocker,
to_tuple,
)
from .helpers import SeededTest
from .test_distributions import (
Expand Down Expand Up @@ -1544,3 +1546,112 @@ def test_Triangular(
prior_samples=prior_samples,
)
assert prior["target"].shape == (prior_samples,) + shape


def generate_shapes(include_params=False, xfail=False):
# fmt: off
mudim_as_event = [
[None, 1, 3, 10, (10, 3), 100],
[(3,)],
[(1,), (3,)],
["cov", "chol", "tau"]
]
# fmt: on
mudim_as_dist = [
[None, 1, 3, 10, (10, 3), 100],
[(10, 3)],
[(1,), (3,), (1, 1), (1, 3), (10, 1), (10, 3)],
["cov", "chol", "tau"],
]
if not include_params:
del mudim_as_event[-1]
del mudim_as_dist[-1]
data = itertools.chain(itertools.product(*mudim_as_event), itertools.product(*mudim_as_dist))
if xfail:
data = list(data)
for index in range(len(data)):
if data[index][0] in (None, 1):
data[index] = pytest.param(
*data[index], marks=pytest.mark.xfail(reason="wait for PR #4214")
)
return data
Sayam753 marked this conversation as resolved.
Show resolved Hide resolved


class TestMvNormal(SeededTest):
@pytest.mark.parametrize(
["sample_shape", "dist_shape", "mu_shape", "param"],
generate_shapes(include_params=True, xfail=False),
ids=str,
)
def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param):
dist = pm.MvNormal.dist(mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape)
output_shape = to_tuple(sample_shape) + dist_shape
assert dist.random(size=sample_shape).shape == output_shape

@pytest.mark.parametrize(
["sample_shape", "dist_shape", "mu_shape"],
generate_shapes(include_params=False, xfail=True),
ids=str,
)
def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape

@pytest.mark.parametrize(
["sample_shape", "dist_shape", "mu_shape"],
generate_shapes(include_params=False, xfail=True),
ids=str,
)
def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape

def test_issue_3758(self):
np.random.seed(42)
ndim = 50
with pm.Model() as model:
a = pm.Normal("a", sigma=100, shape=ndim)
b = pm.Normal("b", mu=a, sigma=1, shape=ndim)
c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim)
d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim)
samples = pm.sample_prior_predictive(1000)

for var in "abcd":
assert not np.isnan(np.std(samples[var]))

def test_issue_3829(self):
with pm.Model() as model:
x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5))
trace_pp = pm.sample_prior_predictive(50)

assert np.shape(trace_pp["x"][0]) == (2, 5)

def test_issue_3706(self):
N = 10
Sigma = np.eye(2)

with pm.Model() as model:

X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2))
betas = pm.Normal("betas", 0, 1, shape=2)
y = pm.Deterministic("y", pm.math.dot(X, betas))

prior_pred = pm.sample_prior_predictive(1)

assert prior_pred["X"].shape == (1, N, 2)
2 changes: 1 addition & 1 deletion pymc3/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def build_toy_dataset(N, K):
)
)
chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True))
comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i]))
comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i], shape=D))

pm.Mixture("x_obs", pi, comp_dist, observed=X)
with model:
Expand Down