Re-enable passing dims alongside shape or size (#5325)
* Remove unused return value from helper functions
* Restore support for passing `dims` alongside `shape` or `size`
* Extract RV creation and `resize_shape` determination code

Closes #4656
michaelosthege authored Jan 9, 2022
1 parent 6570e95 commit 9155922
Showing 3 changed files with 96 additions and 75 deletions.
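
As a quick sketch of the restored behavior, condensed from the new tests in pymc/tests/test_shape_handling.py below (assumes a PyMC build that includes this commit):

import pymc as pm

with pm.Model() as model:
    x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
    # `shape` fixes the dimensionality; the trailing `...` in `dims` labels
    # only the leading dimension and leaves the rest unnamed (None).
    y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", ...))

assert model.RV_dims["y"] == ("dshape", None)
assert y.eval().shape == (2, 3)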
91 changes: 46 additions & 45 deletions pymc/distributions/distribution.py
@@ -19,12 +19,14 @@

from abc import ABCMeta
from functools import singledispatch
from typing import Callable, Iterable, Optional, Sequence
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union

import aesara
import numpy as np

from aeppl.logprob import _logcdf, _logprob
from aesara import tensor as at
from aesara.graph.basic import Variable
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
@@ -36,6 +38,8 @@
Dims,
Shape,
Size,
StrongShape,
WeakDims,
convert_dims,
convert_shape,
convert_size,
@@ -133,6 +137,37 @@ def fn(*args, **kwargs):
return fn


def _make_rv_and_resize_shape(
    *,
    cls,
    dims: Optional[Dims],
    model,
    observed,
    args,
    **kwargs,
) -> Tuple[Variable, Optional[WeakDims], Optional[Union[np.ndarray, Variable]], StrongShape]:
    """Creates the RV and processes dims or observed to determine a resize shape."""
    # Create the RV without dims information, because that's not something tracked at the Aesara level.
    # If necessary we'll later replicate to a different size implied by already known dims.
    rv_out = cls.dist(*args, **kwargs)
    ndim_actual = rv_out.ndim
    resize_shape = None

    # `dims` are only available with this API, because `.dist()` can be used
    # without a modelcontext and dims are not tracked at the Aesara level.
    dims = convert_dims(dims)
    dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
    if dims is not None:
        if dims_can_resize:
            resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
        elif Ellipsis in dims:
            # Replace ... with None entries to match the actual dimensionality.
            dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
    elif observed is not None:
        resize_shape, observed = resize_from_observed(observed, ndim_actual)
    return rv_out, dims, observed, resize_shape


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

@@ -213,28 +248,11 @@ def __new__(
if rng is None:
rng = model.next_rng()

if dims is not None and "shape" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
)
if dims is not None and "size" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
)
dims = convert_dims(dims)

# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, rng=rng, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

# `dims` are only available with this API, because `.dist()` can be used
# without a modelcontext and dims are not tracked at the Aesara level.
if dims is not None:
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif observed is not None:
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)
# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs
)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
@@ -456,35 +474,18 @@ def __new__(
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if dims is not None and "shape" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
)
if dims is not None and "size" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
)
dims = convert_dims(dims)

if rngs is None:
# Create a temporary rv to obtain number of rngs needed
temp_graph = cls.dist(*args, rngs=None, **kwargs)
rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
elif not isinstance(rngs, (list, tuple)):
rngs = [rngs]

# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, rngs=rngs, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

# # `dims` are only available with this API, because `.dist()` can be used
# # without a modelcontext and dims are not tracked at the Aesara level.
if dims is not None:
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif observed is not None:
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)
# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs
)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
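
The Ellipsis-padding branch of _make_rv_and_resize_shape above can be checked in isolation. A minimal standalone sketch (pad_dims is a hypothetical name, not part of this patch):

def pad_dims(dims, ndim_actual):
    # Replace a trailing ... with None entries so the dims tuple
    # matches the RV's actual number of dimensions.
    if Ellipsis in dims:
        return (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
    return dims

assert pad_dims(("dshape", ...), 2) == ("dshape", None)
assert pad_dims(("a", "b", ...), 2) == ("a", "b")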
22 changes: 9 additions & 13 deletions pymc/distributions/shape_utils.py
@@ -427,7 +427,7 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
StrongSize = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]


def convert_dims(dims: Dims) -> Optional[WeakDims]:
def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:
"""Process a user-provided dims variable into None or a valid dims tuple."""
if dims is None:
return None
@@ -487,9 +487,7 @@ def convert_size(size: Size) -> Optional[StrongSize]:
return size


def resize_from_dims(
dims: WeakDims, ndim_implied: int, model
) -> Tuple[int, StrongSize, StrongDims]:
def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSize, StrongDims]:
"""Determines a potential resize shape from a `dims` tuple.
Parameters
@@ -503,10 +501,10 @@ def resize_from_dims(
Returns
-------
ndim_resize : int
Number of dimensions that should be added through resizing.
resize_shape : array-like
The shape of the new dimensions.
Shape of new dimensions that should be prepended.
dims : tuple of (str or None)
Names or None for all dimensions after resizing.
"""
if Ellipsis in dims:
# Auto-complete the dims tuple to the full length.
@@ -525,12 +523,12 @@

# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize])
return ndim_resize, resize_shape, dims
return resize_shape, dims


def resize_from_observed(
observed, ndim_implied: int
) -> Tuple[int, StrongSize, Union[np.ndarray, Variable]]:
) -> Tuple[StrongSize, Union[np.ndarray, Variable]]:
"""Determines a potential resize shape from observations.
Parameters
@@ -542,18 +540,16 @@
Returns
-------
ndim_resize : int
Number of dimensions that should be added through resizing.
resize_shape : array-like
The shape of the new dimensions.
Shape of new dimensions that should be prepended.
observed : scalar, array-like
Observations as numpy array or `Variable`.
"""
if not hasattr(observed, "shape"):
observed = pandas_to_array(observed)
ndim_resize = observed.ndim - ndim_implied
resize_shape = tuple(observed.shape[d] for d in range(ndim_resize))
return ndim_resize, resize_shape, observed
return resize_shape, observed


def find_size(shape=None, size=None, ndim_supp=None):
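
With the unused ndim_resize return value gone, both helpers now return a two-tuple; callers that still need the count can take len(resize_shape). The core of resize_from_observed reduces to a few lines of NumPy, sketched standalone here with a plain array standing in for real observations:

import numpy as np

observed = np.ones((5, 3))  # one extra leading dimension vs. the RV
ndim_implied = 1            # dimensionality implied by the RV itself
ndim_resize = observed.ndim - ndim_implied
resize_shape = tuple(observed.shape[d] for d in range(ndim_resize))
assert resize_shape == (5,)              # shape of dimensions to prepend
assert len(resize_shape) == ndim_resize  # the dropped return value, recovered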
58 changes: 41 additions & 17 deletions pymc/tests/test_shape_handling.py
@@ -293,6 +293,47 @@ def test_param_and_batch_shape_combos(
else:
raise NotImplementedError("Invalid test case parametrization.")

    @pytest.mark.parametrize("ellipsis_in", ["none", "shape", "dims", "both"])
    def test_simultaneous_shape_and_dims(self, ellipsis_in):
        with pm.Model() as pmodel:
            x = pm.ConstantData("x", [1, 2, 3], dims="ddata")

            if ellipsis_in == "none":
                # The shape and dims tuples correspond to each other.
                # Note: No checks are performed that implied shape (x), shape and dims actually match.
                y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", "ddata"))
                assert pmodel.RV_dims["y"] == ("dshape", "ddata")
            elif ellipsis_in == "shape":
                y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", "ddata"))
                assert pmodel.RV_dims["y"] == ("dshape", "ddata")
            elif ellipsis_in == "dims":
                y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", ...))
                assert pmodel.RV_dims["y"] == ("dshape", None)
            elif ellipsis_in == "both":
                y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", ...))
                assert pmodel.RV_dims["y"] == ("dshape", None)

            assert "dshape" in pmodel.dim_lengths
            assert y.eval().shape == (2, 3)

    @pytest.mark.parametrize("with_dims_ellipsis", [False, True])
    def test_simultaneous_size_and_dims(self, with_dims_ellipsis):
        with pm.Model() as pmodel:
            x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
            assert "ddata" in pmodel.dim_lengths

            # Size does not include support dims, so this test must use a dist with support dims.
            kwargs = dict(name="y", size=2, mu=at.ones((3, 4)), cov=at.eye(4))
            if with_dims_ellipsis:
                y = pm.MvNormal(**kwargs, dims=("dsize", ...))
                assert pmodel.RV_dims["y"] == ("dsize", None, None)
            else:
                y = pm.MvNormal(**kwargs, dims=("dsize", "ddata", "dsupport"))
                assert pmodel.RV_dims["y"] == ("dsize", "ddata", "dsupport")

            assert "dsize" in pmodel.dim_lengths
            assert y.eval().shape == (2, 3, 4)

def test_define_dims_on_the_fly(self):
with pm.Model() as pmodel:
agedata = aesara.shared(np.array([10, 20, 30]))
@@ -312,17 +353,6 @@ def test_define_dims_on_the_fly(self):
# The change should propagate all the way through
assert effect.eval().shape == (4,)

@pytest.mark.xfail(reason="Simultaneous use of size and dims is not implemented")
def test_data_defined_size_dimension_can_register_dimname(self):
with pm.Model() as pmodel:
x = pm.ConstantData("x", [[1, 2, 3, 4]], dims=("first", "second"))
assert "first" in pmodel.dim_lengths
assert "second" in pmodel.dim_lengths
# two dimensions are implied; a "third" dimension is created
y = pm.Normal("y", mu=x, size=2, dims=("third", "first", "second"))
assert "third" in pmodel.dim_lengths
assert y.eval().shape() == (2, 1, 4)

def test_can_resize_data_defined_size(self):
with pm.Model() as pmodel:
x = pm.MutableData("x", [[1, 2, 3, 4]], dims=("first", "second"))
@@ -447,9 +477,3 @@ def test_lazy_flavors(self):
def test_invalid_flavors(self):
with pytest.raises(ValueError, match="Passing both"):
pm.Normal.dist(0, 1, shape=(3,), size=(3,))

with pm.Model():
with pytest.raises(ValueError, match="Passing both"):
pm.Normal("n", shape=(2,), dims=("town",))
with pytest.raises(ValueError, match="Passing both"):
pm.Normal("n", dims=("town",), size=(2,))
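
The deleted test_invalid_flavors cases above show the user-facing change: combining dims with shape or size inside a model no longer raises. Condensed from the new test_simultaneous_size_and_dims (same names as in the test):

import aesara.tensor as at
import pymc as pm

with pm.Model() as model:
    x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
    # `size` covers only the batch dimension; `dims` additionally labels
    # the support dimensions of the MvNormal.
    y = pm.MvNormal("y", mu=at.ones((3, 4)), cov=at.eye(4), size=2,
                    dims=("dsize", "ddata", "dsupport"))

assert model.RV_dims["y"] == ("dsize", "ddata", "dsupport")
assert y.eval().shape == (2, 3, 4)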
