Skip to content

Commit

Permalink
Implement shape/dims/size API
Browse files Browse the repository at this point in the history
Dims with Ellipsis are not yet implemented.
Some tests were refactored because size is now implemented more consistently.
  • Loading branch information
michaelosthege committed Apr 8, 2021
1 parent 116cdf3 commit fa43a1a
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 82 deletions.
114 changes: 104 additions & 10 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.
import contextvars
import inspect
import logging
import multiprocessing
import sys
import types
import warnings

from abc import ABCMeta
from copy import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence, Union

import dill

Expand All @@ -35,6 +36,8 @@
import aesara.graph.basic
import aesara.tensor as at

from pymc3.aesaraf import change_rv_size
from pymc3.exceptions import ShapeError
from pymc3.util import UNSET, get_repr_for_variable
from pymc3.vartypes import string_types

Expand All @@ -46,6 +49,8 @@
"NoDistribution",
]

_log = logging.getLogger(__file__)

vectorized_ppc = contextvars.ContextVar(
"vectorized_ppc", default=None
) # type: contextvars.ContextVar[Optional[Callable]]
Expand Down Expand Up @@ -122,6 +127,19 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
return new_cls


def _valid_ellipsis_position(
items: Union[
None,
Sequence[Union[str, type(Ellipsis)]],
Sequence[Union[int, type(Ellipsis)]],
]
):
if items is not None and Ellipsis in items:
if any(i == Ellipsis for i in items[:-1]):
return False
return True


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

Expand All @@ -141,26 +159,102 @@ def __new__(cls, name, *args, **kwargs):
"for a standalone distribution."
)

rng = kwargs.pop("rng", None)
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

# Pop out PyMC3-related kwargs so only the disttribution kwargs remain
rng = kwargs.pop("rng", None)
if rng is None:
rng = model.default_rng

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

data = kwargs.pop("observed", None)

total_size = kwargs.pop("total_size", None)
testval = kwargs.pop("testval", None)
transform = kwargs.pop("transform", UNSET)

shape = kwargs.pop("shape", None)
dims = kwargs.pop("dims", None)
size = kwargs.pop("size", None)

# Raise on unsupported parametrization
if shape is not None and dims is not None:
raise ValueError("Passing both `shape` ({shape}) and `dims` ({dims}) is not supported!")
if dims is not None and size is not None:
raise ValueError("Passing both `dims` ({dims}) and `size` ({size}) is not supported!")
if shape is not None and size is not None:
raise ValueError("Passing both `shape` ({shape}) and `size` ({size}) is not supported!")

# Warn about discouraged parametrization
if shape is not None and not isinstance(shape, (list, tuple)):
warnings.warn("The `shape` parameter should be a list or tuple.", UserWarning)
shape = (shape,)
if dims is not None and not isinstance(dims, (list, tuple)):
warnings.warn("The `dims` parameter should be a list or tuple.", UserWarning)
dims = (dims,)
if size is not None and not isinstance(size, (list, tuple)):
warnings.warn("The `size` parameter should be a list or tuple.", UserWarning)
size = (size,)

if size is not None and Ellipsis in size:
raise ValueError("The `size` parameter cannot contain an Ellipsis. Actual: {size}")
if not _valid_ellipsis_position(shape):
raise ValueError(
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
)
if not _valid_ellipsis_position(dims):
raise ValueError(
f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}"
)

if "shape" in kwargs:
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")
# Create the RV without specifying size or testval.
# The size will be expanded later (if necessary) and only then the testval fits.
rv_native = cls.dist(*args, rng=rng, testval=None, size=None, **kwargs)
implied_ndim = rv_native.ndim
implied_dims = ... # TODO: infer dimension names from Variables!

if shape is None and dims is None and size is None:
size = ()
elif size is not None:
# User already specified how to expand the RV
pass
elif shape is not None:
# Infer size from shape
if Ellipsis in shape:
size = tuple(shape[:-1])
else:
size = tuple(shape[: len(shape) - implied_ndim])
if size:
warnings.warn(
f"The specified shape {shape} has {len(shape)} dimensions. "
f"This is more than the {implied_ndim} dimensions implied by the distribution parameters. "
f"To replicate the RV beyond its implied dimensionality use `size={size}` instead.",
UserWarning,
)
elif dims is not None:
# Infer size from dims (and coords)
if Ellipsis in dims:
raise NotImplementedError("Ellipsis-based dims inference is not implemented.")
dims = tuple(dims[:-1]) + implied_dims
dimshape = tuple(len(model.coords[dname]) for dname in dims)
size = tuple(dimshape[: len(dimshape) - implied_ndim])
else:
raise Exception("This should have been unreachable code.")

transform = kwargs.pop("transform", UNSET)
if size:
rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True)
else:
rv_out = rv_native

rv_out = cls.dist(*args, rng=rng, **kwargs)
if rv_out.ndim != len(size) + implied_ndim:
raise ShapeError(
"Created RV has incorrect dimensionality.",
actual=rv_out.ndim,
expected=len(size) + implied_ndim,
)

if testval is not None:
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
rv_out.tag.test_value = testval

return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)

Expand Down
34 changes: 27 additions & 7 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,42 @@ def test_shared_data_as_rv_input(self):
"""
with pm.Model() as m:
x = pm.Data("x", [1.0, 2.0, 3.0])
_ = pm.Normal("y", mu=x, size=3)
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
assert x.eval().shape == (3,)
y = pm.Normal("y", mu=x, size=2)
assert y.eval().shape == (2, 3)
idata = pm.sample(
chains=1,
tune=500,
draws=550,
return_inferencedata=True,
compute_convergence_checks=False,
)
samples = idata.posterior["y"]
assert samples.shape == (1, 550, 2, 3)

np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1)
np.testing.assert_allclose(
np.array([1.0, 2.0, 3.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
)

with m:
pm.set_data({"x": np.array([2.0, 4.0, 6.0])})
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
assert x.eval().shape == (3,)
assert y.eval().shape == (2, 3)
idata = pm.sample(
chains=1,
tune=500,
draws=620,
return_inferencedata=True,
compute_convergence_checks=False,
)
samples = idata.posterior["y"]
assert samples.shape == (1, 620, 2, 3)

np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
np.testing.assert_allclose(
np.array([2.0, 4.0, 6.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
)

def test_shared_scalar_as_rv_input(self):
# See https://github.com/pymc-devs/pymc3/issues/3139
Expand Down
7 changes: 1 addition & 6 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,7 @@ def get_random_variable(self, shape, with_vector_params=False, name=None):
# in the test case parametrization "None" means "no specified (default)"
return self.distribution(name, transform=None, **params)
else:
ndim_supp = self.distribution.rv_op.ndim_supp
if ndim_supp == 0:
size = shape
else:
size = shape[:-ndim_supp]
return self.distribution(name, size=size, transform=None, **params)
return self.distribution(name, shape=shape, transform=None, **params)
except TypeError:
if np.sum(np.atleast_1d(shape)) == 0:
pytest.skip("Timeseries must have positive shape")
Expand Down
30 changes: 15 additions & 15 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_return_inferencedata(self, monkeypatch):
return_inferencedata=True,
discard_tuned_samples=True,
idata_kwargs={"prior": prior},
random_seed=-1
random_seed=-1,
)
assert "prior" in result
assert isinstance(result, InferenceData)
Expand Down Expand Up @@ -380,11 +380,11 @@ def test_shared_named(self):
"theta0",
mu=np.atleast_2d(0),
tau=np.atleast_2d(1e20),
size=(1, 1),
shape=(1, 1),
testval=np.atleast_2d(0),
)
theta = pm.Normal(
"theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), size=(1, 1)
"theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), shape=(1, 1)
)
res = theta.eval()
assert np.isclose(res, 0.0)
Expand All @@ -396,11 +396,11 @@ def test_shared_unnamed(self):
"theta0",
mu=np.atleast_2d(0),
tau=np.atleast_2d(1e20),
size=(1, 1),
shape=(1, 1),
testval=np.atleast_2d(0),
)
theta = pm.Normal(
"theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), size=(1, 1)
"theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), shape=(1, 1)
)
res = theta.eval()
assert np.isclose(res, 0.0)
Expand All @@ -412,11 +412,11 @@ def test_constant_named(self):
"theta0",
mu=np.atleast_2d(0),
tau=np.atleast_2d(1e20),
size=(1, 1),
shape=(1, 1),
testval=np.atleast_2d(0),
)
theta = pm.Normal(
"theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), size=(1, 1)
"theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), shape=(1, 1)
)

res = theta.eval()
Expand Down Expand Up @@ -931,14 +931,14 @@ def test_ignores_observed(self):
npt.assert_array_almost_equal(prior["positive_mu"], np.abs(prior["mu"]), decimal=4)

def test_respects_shape(self):
for shape in (2, (2,), (10, 2), (10, 10)):
for shape in ((2,), (10, 2), (10, 10)):
with pm.Model():
mu = pm.Gamma("mu", 3, 1, size=1)
goals = pm.Poisson("goals", mu, size=shape)
mu = pm.Gamma("mu", 3, 1)
assert mu.eval().shape == ()
goals = pm.Poisson("goals", mu, shape=shape)
assert goals.eval().shape == shape, f"Current shape setting: {shape}"
trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"])
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
if shape == 2: # want to test shape as an int
shape = (2,)
assert trace1["goals"].shape == (10,) + shape
assert trace2["goals"].shape == (10,) + shape

Expand Down Expand Up @@ -966,14 +966,14 @@ def test_multivariate2(self):

def test_layers(self):
with pm.Model() as model:
a = pm.Uniform("a", lower=0, upper=1, size=10)
b = pm.Binomial("b", n=1, p=a, size=10)
a = pm.Uniform("a", lower=0, upper=1, size=5)
b = pm.Binomial("b", n=1, p=a, size=7)

model.default_rng.get_value(borrow=True).seed(232093)

b_sampler = aesara.function([], b)
avg = np.stack([b_sampler() for i in range(10000)]).mean(0)
npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2)
npt.assert_array_almost_equal(avg, 0.5 * np.ones((7, 5)), decimal=2)

def test_transformed(self):
n = 18
Expand Down
Loading

0 comments on commit fa43a1a

Please sign in to comment.