Skip to content

Commit

Permalink
Implement shape/dims/size API
Browse files Browse the repository at this point in the history
Dims with Ellipsis are not yet implemented.
Some tests were refactored because size is now implemented more consistently.
  • Loading branch information
michaelosthege committed Apr 11, 2021
1 parent 116cdf3 commit d3fde48
Show file tree
Hide file tree
Showing 10 changed files with 453 additions and 102 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

### New Features
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
- The shape handling of random variables was improved dramatically. For an overview refer to https://github.com/pymc-devs/pymc3/wiki/v4-shapes (also see [#4625](https://github.com/pymc-devs/pymc3/pull/4625))
- ...

### Maintenance
Expand Down
237 changes: 209 additions & 28 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,28 @@
# limitations under the License.
import contextvars
import inspect
import logging
import multiprocessing
import sys
import types
import warnings

from abc import ABCMeta
from copy import copy
from typing import TYPE_CHECKING
from typing import Any, Optional, Tuple, Union

import aesara
import aesara.graph.basic
import aesara.tensor as at
import dill

from aesara.graph.basic import Variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import specify_shape

from pymc3.aesaraf import change_rv_size
from pymc3.distributions import _logcdf, _logp

if TYPE_CHECKING:
from typing import Optional, Callable

import aesara
import aesara.graph.basic
import aesara.tensor as at

from pymc3.exceptions import BestPracticeWarning
from pymc3.util import UNSET, get_repr_for_variable
from pymc3.vartypes import string_types

Expand All @@ -46,12 +46,18 @@
"NoDistribution",
]

_log = logging.getLogger(__file__)

vectorized_ppc = contextvars.ContextVar(
"vectorized_ppc", default=None
) # type: contextvars.ContextVar[Optional[Callable]]

PLATFORM = sys.platform

Shape = Union[Tuple[Union[str, type(Ellipsis)]], Variable]
Dims = Tuple[Union[str, type(Ellipsis)]]
Size = Union[int, Tuple[int]]


class _Unpickling:
pass
Expand Down Expand Up @@ -122,13 +128,114 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
return new_cls


def _valid_ellipsis_position(items: Union[None, Shape, Dims, Size]) -> bool:
if items is not None and not isinstance(items, Variable) and Ellipsis in items:
if any(i == Ellipsis for i in items[:-1]):
return False
return True


def _validate_shape_dims_size(
shape: Any = None, dims: Any = None, size: Any = None
) -> Tuple[Optional[Shape], Optional[Dims], Optional[Size]]:
# Raise on unsupported parametrization
if shape is not None and dims is not None:
raise ValueError("Passing both `shape` ({shape}) and `dims` ({dims}) is not supported!")
if dims is not None and size is not None:
raise ValueError("Passing both `dims` ({dims}) and `size` ({size}) is not supported!")
if shape is not None and size is not None:
raise ValueError("Passing both `shape` ({shape}) and `size` ({size}) is not supported!")

# Raise on invalid types
if not isinstance(shape, (type(None), int, list, tuple, aesara.graph.basic.Variable)):
raise ValueError("The `shape` parameter must be an int, list or tuple.")
if not isinstance(dims, (type(None), str, list, tuple)):
raise ValueError("The `dims` parameter must be a str, list or tuple.")
if not isinstance(size, (type(None), int, list, tuple)):
raise ValueError("The `size` parameter must be an int, list or tuple.")

# Warn about backwards-compatible but discouraged parametrization
if isinstance(shape, int):
warnings.warn("The `shape` parameter should be a list or tuple.", BestPracticeWarning)
shape = (shape,)
if isinstance(dims, str):
warnings.warn("The `dims` parameter should be a list or tuple.", BestPracticeWarning)
dims = (dims,)
if isinstance(size, int):
# size as int is okay, because that's the most common use case
size = (size,)

# Convert to actual tuples
if not isinstance(shape, (type(None), tuple, Variable)):
shape = tuple(shape)
if not isinstance(dims, (type(None), tuple)):
dims = tuple(dims)
if not isinstance(size, (type(None), tuple)):
size = tuple(size)

if not _valid_ellipsis_position(shape):
raise ValueError(
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
)
if not _valid_ellipsis_position(dims):
raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
if size is not None and Ellipsis in size:
raise ValueError("The `size` parameter cannot contain an Ellipsis. Actual: {size}")
return shape, dims, size


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

rv_class = None
rv_op = None

def __new__(cls, name, *args, **kwargs):
def __new__(
cls,
name: str,
*args,
rng=None,
dims: Optional[Dims] = None,
testval=None,
observed=None,
total_size=None,
transform=UNSET,
**kwargs,
) -> RandomVariable:
"""Adds a RandomVariable corresponding to a PyMC3 distribution to the current model.
Note that all remaining kwargs must be compatible with .dist()
Parameters
----------
cls : type
A PyMC3 distribution.
name : str
Name for the new model variable.
rng : optional
Random number generator to use with the RandomVariable.
dims : tuple, optional
A tuple of dimension names known to the model.
testval : optional
Test value to be attached to the output RV.
Must match its shape exactly.
observed : optional
Observed data to be passed when registering the random variable in the model.
See `Model.register_rv`.
total_size : float, optional
See `Model.register_rv`.
transform : optional
See `Model.register_rv`.
**kwargs
Keyword arguments that will be forwarded to .dist().
Most prominently: `shape` and `size`
Returns
-------
rv : RandomVariable
The created RV, registered in the Model.
"""

try:
from pymc3.model import Model

Expand All @@ -141,40 +248,114 @@ def __new__(cls, name, *args, **kwargs):
"for a standalone distribution."
)

rng = kwargs.pop("rng", None)

if rng is None:
rng = model.default_rng

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

data = kwargs.pop("observed", None)
if rng is None:
rng = model.default_rng

total_size = kwargs.pop("total_size", None)
# Create the RV without specifying testval, because the testval may
# have a shape that only matches after replicating with a size implied
# by dims (see below).
rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)

dims = kwargs.pop("dims", None)
# `dims` are only available with this API, because `.dist()` can be used
# without a modelcontext and dims are not yet tracked at the Aesara level.
if dims is not None:
# Infer size from dims (and coords)
if Ellipsis in dims:
raise NotImplementedError("Ellipsis-based dims are not implemented.")
# TODO: Get implied dimensions from input Variables! (see https://github.com/pymc-devs/aesara/issues/352)
implied_dims = ...
dims = tuple(dims[:-1]) + implied_dims

if "shape" in kwargs:
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")
dimshape = tuple(len(model.coords[dname]) for dname in dims)
size = tuple(dimshape[: len(dimshape) - rv_out.ndim])

transform = kwargs.pop("transform", UNSET)
if size:
# A batch size was specified through dims!
rv_out = change_rv_size(rv_var=rv_out, new_size=size, expand=True)

rv_out = cls.dist(*args, rng=rng, **kwargs)
if testval is not None:
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
rv_out.tag.test_value = testval

return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform)

@classmethod
def dist(cls, dist_params, **kwargs):
def dist(
cls,
dist_params,
*,
shape: Optional[Shape] = None,
size: Optional[Size] = None,
testval=None,
**kwargs,
) -> RandomVariable:
"""Creates a RandomVariable corresponding to the `cls` distribution.
Parameters
----------
dist_params
shape : tuple, optional
A tuple of sizes for each dimension of the new RV.
Ellipsis (...) may be used in the last position of the tuple,
and automatically expand to the shape implied by RV inputs.
Without Ellipsis, a `SpecifyShape` Op is automatically applied,
constraining this model variable to exactly the specified shape.
size : int, tuple, Variable, optional
A scalar or tuple for replicating the RV in addition
to its implied shape/dimensionality.
testval : optional
Test value to be attached to the output RV.
Must match its shape exactly.
testval = kwargs.pop("testval", None)
Returns
-------
rv : RandomVariable
The created RV.
"""
if "dims" in kwargs:
raise NotImplementedError("The use of a `.dist(dims=...)` API is not yet supported.")

shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
assert_shape = None

# Create the RV without specifying size or testval.
# The size will be expanded later (if necessary) and only then the testval fits.
rv_native = cls.rv_op(*dist_params, size=None, **kwargs)

if shape is None and size is None:
size = ()
elif shape is not None:
# SpecifyShape is automatically applied for symbolic and non-Ellipsis shapes
if isinstance(shape, Variable):
assert_shape = shape
size = ()
else:
if Ellipsis in shape:
size = tuple(shape[:-1])
else:
size = tuple(shape[: len(shape) - rv_native.ndim])
assert_shape = shape
# no-op conditions:
# `elif size is not None` (User already specified how to expand the RV)
# `else` (Unreachable)

if size:
rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True)
else:
rv_out = rv_native

rv_var = cls.rv_op(*dist_params, **kwargs)
if assert_shape is not None:
rv_out = specify_shape(rv_out, shape=assert_shape)

if testval is not None:
rv_var.tag.test_value = testval
rv_out.tag.test_value = testval

return rv_var
return rv_out

def _distr_parameters_for_repr(self):
"""Return the names of the parameters for this distribution (e.g. "mu"
Expand Down
4 changes: 4 additions & 0 deletions pymc3/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
]


class BestPracticeWarning(UserWarning):
pass


class SamplingError(RuntimeError):
pass

Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/sampler_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class BetaBinomialFixture(KnownCDF):
@classmethod
def make_model(cls):
with pm.Model() as model:
p = pm.Beta("p", [0.5, 0.5, 1.0], [0.5, 0.5, 1.0], size=3)
p = pm.Beta("p", [0.5, 0.5, 1.0], [0.5, 0.5, 1.0], shape=(3,))
pm.Binomial("y", p=p, n=[4, 12, 9], observed=[1, 2, 9])
return model

Expand Down
34 changes: 27 additions & 7 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,42 @@ def test_shared_data_as_rv_input(self):
"""
with pm.Model() as m:
x = pm.Data("x", [1.0, 2.0, 3.0])
_ = pm.Normal("y", mu=x, size=3)
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
assert x.eval().shape == (3,)
y = pm.Normal("y", mu=x, size=2)
assert y.eval().shape == (2, 3)
idata = pm.sample(
chains=1,
tune=500,
draws=550,
return_inferencedata=True,
compute_convergence_checks=False,
)
samples = idata.posterior["y"]
assert samples.shape == (1, 550, 2, 3)

np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1)
np.testing.assert_allclose(
np.array([1.0, 2.0, 3.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
)

with m:
pm.set_data({"x": np.array([2.0, 4.0, 6.0])})
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
assert x.eval().shape == (3,)
assert y.eval().shape == (2, 3)
idata = pm.sample(
chains=1,
tune=500,
draws=620,
return_inferencedata=True,
compute_convergence_checks=False,
)
samples = idata.posterior["y"]
assert samples.shape == (1, 620, 2, 3)

np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
np.testing.assert_allclose(
np.array([2.0, 4.0, 6.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
)

def test_shared_scalar_as_rv_input(self):
# See https://github.com/pymc-devs/pymc3/issues/3139
Expand Down
7 changes: 1 addition & 6 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,7 @@ def get_random_variable(self, shape, with_vector_params=False, name=None):
# in the test case parametrization "None" means "no specified (default)"
return self.distribution(name, transform=None, **params)
else:
ndim_supp = self.distribution.rv_op.ndim_supp
if ndim_supp == 0:
size = shape
else:
size = shape[:-ndim_supp]
return self.distribution(name, size=size, transform=None, **params)
return self.distribution(name, shape=shape, transform=None, **params)
except TypeError:
if np.sum(np.atleast_1d(shape)) == 0:
pytest.skip("Timeseries must have positive shape")
Expand Down
Loading

0 comments on commit d3fde48

Please sign in to comment.