Skip to content

Commit

Permalink
WIP Remove Ellipses, fix shape handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
twiecki committed May 14, 2021
1 parent 2d12dd7 commit 0f00a31
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 54 deletions.
68 changes: 47 additions & 21 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from aesara.tensor.shape import SpecifyShape, specify_shape

from pymc3.aesaraf import change_rv_size, pandas_to_array
from pymc3.exceptions import ShapeError
from pymc3.distributions import _logcdf, _logp
from pymc3.util import UNSET, get_repr_for_variable
from pymc3.vartypes import string_types
Expand Down Expand Up @@ -343,32 +344,57 @@ def dist(

shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
assert_shape = None
batch_shape = ()
ndim_supp = cls.rv_op.ndim_supp

# Create the RV without specifying size or testval.
# The size will be expanded later (if necessary) and only then the testval fits.
rv_native = cls.rv_op(*dist_params, size=None, **kwargs)

if shape is None and size is None:
size = ()
elif shape is not None:
# SpecifyShape is automatically applied for symbolic and non-Ellipsis shapes
if shape is not None and size is None:
# A shape was passed, but internally we need `size`. Here we slice
# `shape` into `size` according to the ndim_support.
# Also, we determine the expected shape for an automatic SpecifyShape.
if isinstance(shape, Variable):
assert_shape = shape
size = ()
batch_shape = ()
else:
if Ellipsis in shape:
size = tuple(shape[:-1])
else:
size = tuple(shape[: len(shape) - rv_native.ndim])
assert_shape = shape
# no-op conditions:
# `elif size is not None` (User already specified how to expand the RV)
# `else` (Unreachable)

if size:
rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True)
batch_shape = tuple(shape[: len(shape) - ndim_supp])
assert_shape = shape
else:
batch_shape = size

# Create the RV with a `size` right away.
rv_out = cls.rv_op(*dist_params, size=batch_shape, **kwargs)

# There are three ndims at play:
# 1. ndim_supp (inherent to the RV Op)
# 2. ndim_batch (dimensions in addition to ndim_supp)
# 3. ndim_actual (can be different than ndim_supp+ndim_batch in some edge cases)
ndim_actual = rv_out.ndim
if shape is not None:
# The number of batch dimensions is simply the length of shape
ndim_batch = len(tuple(shape)) - ndim_supp
elif size is not None:
ndim_batch = len(tuple(size))
else:
rv_out = rv_native
ndim_batch = ndim_actual - ndim_supp

# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), size=2).
if ndim_actual != ndim_batch + ndim_supp and size is None:
# There are shape dimensions that go beyond what's implied by the RV parameters.
# Recreate the RV without passing `size`, this time creating batch dimensions with `change_rv_size`.
rv_out = change_rv_size(
rv_var=cls.rv_op(*dist_params, size=None, **kwargs),
#new_size=shape[:-rv_out.ndim],
new_size=shape[:-ndim_batch] if ndim_batch > 0 else (),
expand=True
)
if not rv_out.ndim == ndim_batch + ndim_supp:
raise ShapeError(f"Resized RV does not have the expected dimensionality.", actual=rv_out.ndim, expected=ndim_batch + ndim_supp)

# Now deal with edge cases where RV.ndim != len(size) + ndim_supp.
if size is not None and ndim_actual != len(tuple(size)) + ndim_supp:
warnings.warn(
f"Attention. You may have expected a ({len(tuple(size))}+{ndim_supp})-dimensional RV, but the resulting RV will be {ndim_actual}-dimensional.",
UserWarning
)

if assert_shape is not None:
rv_out = specify_shape(rv_out, shape=assert_shape)
Expand Down
81 changes: 48 additions & 33 deletions pymc3/tests/test_shape_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import pymc3 as pm

from pymc3.exceptions import ShapeError
from pymc3.distributions.distribution import _validate_shape_dims_size
from pymc3.distributions.shape_utils import (
broadcast_dist_samples_shape,
Expand Down Expand Up @@ -232,9 +233,9 @@ class TestShapeDimsSize:
[
"implicit",
"shape",
"shape...",
# "shape...",
"dims",
"dims...",
# "dims...",
"size",
],
)
Expand Down Expand Up @@ -265,28 +266,29 @@ def test_param_and_batch_shape_combos(
if parametrization == "implicit":
rv = pm.Normal("rv", mu=mu).shape == param_shape
else:
expected_shape = batch_shape + param_shape
if parametrization == "shape":
rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape)
assert rv.eval().shape == batch_shape + param_shape
elif parametrization == "shape...":
rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
assert rv.eval().shape == batch_shape + param_shape
assert rv.eval().shape == expected_shape
# elif parametrization == "shape...":
# rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
# assert rv.eval().shape == batch_shape + param_shape
elif parametrization == "dims":
rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims)
assert rv.eval().shape == batch_shape + param_shape
elif parametrization == "dims...":
rv = pm.Normal("rv", mu=mu, dims=(*batch_dims, ...))
n_size = len(batch_shape)
n_implied = len(param_shape)
ndim = n_size + n_implied
assert len(pmodel.RV_dims["rv"]) == ndim, pmodel.RV_dims
assert len(pmodel.RV_dims["rv"][:n_size]) == len(batch_dims)
assert len(pmodel.RV_dims["rv"][n_size:]) == len(param_dims)
if n_implied > 0:
assert pmodel.RV_dims["rv"][-1] is None
assert rv.eval().shape == expected_shape
# elif parametrization == "dims...":
# rv = pm.Normal("rv", mu=mu, dims=(*batch_dims, ...))
# n_size = len(batch_shape)
# n_implied = len(param_shape)
# ndim = n_size + n_implied
# assert len(pmodel.RV_dims["rv"]) == ndim, pmodel.RV_dims
# assert len(pmodel.RV_dims["rv"][:n_size]) == len(batch_dims)
# assert len(pmodel.RV_dims["rv"][n_size:]) == len(param_dims)
# if n_implied > 0:
# assert pmodel.RV_dims["rv"][-1] is None
elif parametrization == "size":
rv = pm.Normal("rv", mu=mu, size=batch_shape)
assert rv.eval().shape == batch_shape + param_shape
rv = pm.Normal("rv", mu=mu, size=batch_shape + param_shape)
assert rv.eval().shape == expected_shape
else:
raise NotImplementedError("Invalid test case parametrization.")

Expand Down Expand Up @@ -352,11 +354,28 @@ def test_dist_api_works(self):
pm.Normal.dist(mu=mu, dims=("town",))
assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,)
assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
assert pm.Normal.dist(mu=mu, size=(4,)).eval().shape == (4, 3)
# assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
assert pm.Normal.dist(mu=mu, size=(3,)).eval().shape == (3,)
assert pm.Normal.dist(mu=mu, size=(4, 3)).eval().shape == (4, 3)

def test_mvnormal_ndim_warning(self):
with pytest.warns(None):
# parameters add 1 batch dimension (4), shape adds another (5)
rv = pm.MvNormal.dist(mu=np.ones((4, 3)), cov=np.eye(3), shape=(5, 4, 3))
assert rv.ndim == 3
assert tuple(rv.shape.eval()) == (5, 4, 3)

# with pytest.warns(None):
# rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
# assert rv.ndim == 5
# assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)

# When using `size` the API behaves like Aesara/NumPy
with pytest.warns(UserWarning, match=r"You may have expected a \(2\+1\)-dimensional RV, but the resulting RV will be 5-dimensional"):
pm.MvNormal.dist(mu=np.ones((5, 4, 3)), cov=np.eye(3), size=(5, 4))

def test_auto_assert_shape(self):
with pytest.raises(AssertionError, match="will never match"):
with pytest.raises(ShapeError, match="does not have the expected dimensionality"):
pm.Normal.dist(mu=[1, 2], shape=[])

mu = at.vector(name="mu_input")
Expand All @@ -371,11 +390,7 @@ def test_auto_assert_shape(self):
s = at.vector(dtype="int32")
rv = pm.Uniform.dist(2, [4, 5], shape=s)
f = aesara.function([s], rv, mode=aesara.Mode("py"))
f(
[
2,
]
)
f([2,])
with pytest.raises(
AssertionError,
match=r"Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(3, 4\).",
Expand Down Expand Up @@ -413,9 +428,9 @@ def test_invalid_flavors(self):
_validate_shape_dims_size(size="notasize")

# invalid ellipsis positions
with pytest.raises(ValueError, match="may only appear in the last position"):
_validate_shape_dims_size(shape=(3, ..., 2))
with pytest.raises(ValueError, match="may only appear in the last position"):
_validate_shape_dims_size(dims=(..., "town"))
with pytest.raises(ValueError, match="cannot contain"):
_validate_shape_dims_size(size=(3, ...))
# with pytest.raises(ValueError, match="may only appear in the last position"):
# _validate_shape_dims_size(shape=(3, ..., 2))
# with pytest.raises(ValueError, match="may only appear in the last position"):
# _validate_shape_dims_size(dims=(..., "town"))
# with pytest.raises(ValueError, match="cannot contain"):
# _validate_shape_dims_size(size=(3, ...))

0 comments on commit 0f00a31

Please sign in to comment.