Skip to content

Commit

Permalink
Fix MvNormal.random (#4207)
Browse files Browse the repository at this point in the history
* Fixed MvNormal.random method

* Added some comments

* Handled corner case when self.shape is not provided

* Fixed comments

* Considered the corner case of 'point' as well.

* Fixed MvNormal.random method

* Handled sample and batch dimensions in tau parametrization using numpy

Added batch dimensions to all parametrization

* Modified logic while inserting batch dimensions to parametrization

* Used shapes_utils.broadcast_dist_samples_to function for broadcasting

* Make pylint pass

* Make test passes 🤞 hopefully

* Modified logic and added tests

* Given a mention in release notes
  • Loading branch information
Sayam753 authored Nov 29, 2020
1 parent 4cccb46 commit 1c290ef
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 53 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116))
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))
- Refactored MvNormal.random method with better handling of sample, batch and event shapes. [#4207](https://github.com/pymc-devs/pymc3/pull/4207)

### Documentation
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).
Expand Down
83 changes: 31 additions & 52 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .continuous import ChiSquared, Normal
from .special import gammaln, multigammaln
from .dist_math import bound, logpow, factln
from .shape_utils import to_tuple
from .shape_utils import to_tuple, broadcast_dist_samples_to
from ..math import kron_dot, kron_diag, kron_solve_lower, kronecker


Expand Down Expand Up @@ -250,58 +250,36 @@ def random(self, point=None, size=None):
-------
array
"""
if size is None:
size = tuple()
else:
if not isinstance(size, tuple):
try:
size = tuple(size)
except TypeError:
size = (size,)
size = to_tuple(size)

if self._cov_type == "cov":
mu, cov = draw_values([self.mu, self.cov], point=point, size=size)
if mu.shape[-1] != cov.shape[-1]:
raise ValueError("Shapes for mu and cov don't match")
param_attribute = getattr(self, "chol_cov" if self._cov_type == "chol" else self._cov_type)
mu, param = draw_values([self.mu, param_attribute], point=point, size=size)

try:
dist = stats.multivariate_normal(mean=mu, cov=cov, allow_singular=True)
except ValueError:
size += (mu.shape[-1],)
return np.nan * np.zeros(size)
return dist.rvs(size)
elif self._cov_type == "chol":
mu, chol = draw_values([self.mu, self.chol_cov], point=point, size=size)
if size and mu.ndim == len(size) and mu.shape == size:
mu = mu[..., np.newaxis]
if mu.shape[-1] != chol.shape[-1] and mu.shape[-1] != 1:
raise ValueError("Shapes for mu and chol don't match")
broadcast_shape = np.broadcast(np.empty(mu.shape[:-1]), np.empty(chol.shape[:-2])).shape

mu = np.broadcast_to(mu, broadcast_shape + (chol.shape[-1],))
chol = np.broadcast_to(chol, broadcast_shape + chol.shape[-2:])
# If mu and chol were fixed by the point, only the standard normal
# should change
if mu.shape[: len(size)] != size:
std_norm_shape = size + mu.shape
else:
std_norm_shape = mu.shape
standard_normal = np.random.standard_normal(std_norm_shape)
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)
else:
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
if mu.shape[-1] != tau[0].shape[-1]:
raise ValueError("Shapes for mu and tau don't match")
dist_shape = to_tuple(self.shape)
output_shape = size + dist_shape

size += (mu.shape[-1],)
try:
chol = linalg.cholesky(tau, lower=True)
except linalg.LinAlgError:
return np.nan * np.zeros(size)
# Simple, there can be only be 1 batch dimension, only available from `mu`.
# Insert it into `param` before events, if there is a sample shape in front.
if param.ndim > 2 and dist_shape[:-1]:
param = param.reshape(size + (1,) + param.shape[-2:])

mu = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)[0]
param = np.broadcast_to(param, shape=output_shape + dist_shape[-1:])

assert mu.shape == output_shape
assert param.shape == output_shape + dist_shape[-1:]

if self._cov_type == "cov":
chol = np.linalg.cholesky(param)
elif self._cov_type == "chol":
chol = param
else: # tau -> chol -> swapaxes (chol, -1, -2) -> inv ...
lower_chol = np.linalg.cholesky(param)
upper_chol = np.swapaxes(lower_chol, -1, -2)
chol = np.linalg.inv(upper_chol)

standard_normal = np.random.standard_normal(size)
transformed = linalg.solve_triangular(chol, standard_normal.T, lower=True)
return mu + transformed.T
standard_normal = np.random.standard_normal(output_shape)
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)

def logp(self, value):
"""
Expand Down Expand Up @@ -399,13 +377,13 @@ def random(self, point=None, size=None):
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
if self._cov_type == "cov":
(cov,) = draw_values([self.cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
elif self._cov_type == "tau":
(tau,) = draw_values([self.tau], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
else:
(chol,) = draw_values([self.chol_cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)

samples = dist.random(point, size)

Expand Down Expand Up @@ -1915,6 +1893,7 @@ def random(self, point=None, size=None):
"""
# Expand params into terms MvNormal can understand to force consistency
self._setup_random()
self.mv_params["shape"] = self.shape
dist = MvNormal.dist(**self.mv_params)
return dist.random(point, size)

Expand Down
111 changes: 111 additions & 0 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import pytest
import numpy as np
import numpy.testing as npt
Expand All @@ -27,6 +28,7 @@
draw_values,
_DrawValuesContext,
_DrawValuesContextBlocker,
to_tuple,
)
from .helpers import SeededTest
from .test_distributions import (
Expand Down Expand Up @@ -1544,3 +1546,112 @@ def test_Triangular(
prior_samples=prior_samples,
)
assert prior["target"].shape == (prior_samples,) + shape


def generate_shapes(include_params=False, xfail=False):
# fmt: off
mudim_as_event = [
[None, 1, 3, 10, (10, 3), 100],
[(3,)],
[(1,), (3,)],
["cov", "chol", "tau"]
]
# fmt: on
mudim_as_dist = [
[None, 1, 3, 10, (10, 3), 100],
[(10, 3)],
[(1,), (3,), (1, 1), (1, 3), (10, 1), (10, 3)],
["cov", "chol", "tau"],
]
if not include_params:
del mudim_as_event[-1]
del mudim_as_dist[-1]
data = itertools.chain(itertools.product(*mudim_as_event), itertools.product(*mudim_as_dist))
if xfail:
data = list(data)
for index in range(len(data)):
if data[index][0] in (None, 1):
data[index] = pytest.param(
*data[index], marks=pytest.mark.xfail(reason="wait for PR #4214")
)
return data


class TestMvNormal(SeededTest):
@pytest.mark.parametrize(
["sample_shape", "dist_shape", "mu_shape", "param"],
generate_shapes(include_params=True, xfail=False),
ids=str,
)
def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param):
dist = pm.MvNormal.dist(mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape)
output_shape = to_tuple(sample_shape) + dist_shape
assert dist.random(size=sample_shape).shape == output_shape

@pytest.mark.parametrize(
["sample_shape", "dist_shape", "mu_shape"],
generate_shapes(include_params=False, xfail=True),
ids=str,
)
def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape

@pytest.mark.parametrize(
["sample_shape", "dist_shape", "mu_shape"],
generate_shapes(include_params=False, xfail=True),
ids=str,
)
def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
sd_dist = pm.Exponential.dist(1.0, shape=3)
chol, corr, stds = pm.LKJCholeskyCov(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
prior = pm.sample_prior_predictive(samples=sample_shape)

assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape

def test_issue_3758(self):
np.random.seed(42)
ndim = 50
with pm.Model() as model:
a = pm.Normal("a", sigma=100, shape=ndim)
b = pm.Normal("b", mu=a, sigma=1, shape=ndim)
c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim)
d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim)
samples = pm.sample_prior_predictive(1000)

for var in "abcd":
assert not np.isnan(np.std(samples[var]))

def test_issue_3829(self):
with pm.Model() as model:
x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5))
trace_pp = pm.sample_prior_predictive(50)

assert np.shape(trace_pp["x"][0]) == (2, 5)

def test_issue_3706(self):
N = 10
Sigma = np.eye(2)

with pm.Model() as model:

X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2))
betas = pm.Normal("betas", 0, 1, shape=2)
y = pm.Deterministic("y", pm.math.dot(X, betas))

prior_pred = pm.sample_prior_predictive(1)

assert prior_pred["X"].shape == (1, N, 2)
2 changes: 1 addition & 1 deletion pymc3/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def build_toy_dataset(N, K):
)
)
chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True))
comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i]))
comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i], shape=D))

pm.Mixture("x_obs", pi, comp_dist, observed=X)
with model:
Expand Down

0 comments on commit 1c290ef

Please sign in to comment.