Skip to content

Commit

Permalink
Automatically add SpecifyShape Op when full-length shape is given
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege authored and twiecki committed Apr 20, 2021
1 parent ed29203 commit c99f15c
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 3 deletions.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
### New Features
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)):
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. A `SpecifyShape` `Op` is added automatically unless `Ellipsis` is used. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
- The `size` kwarg creates new dimensions in addition to what is implied by RV parameters.
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
Expand Down
3 changes: 3 additions & 0 deletions pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import SpecifyShape
from aesara.tensor.sharedvar import SharedVariable
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
from aesara.tensor.var import TensorVariable
Expand Down Expand Up @@ -146,6 +147,8 @@ def change_rv_size(
Expand the existing size by `new_size`.
"""
if isinstance(rv_var.owner.op, SpecifyShape):
rv_var = rv_var.owner.inputs[0]
rv_node = rv_var.owner
rng, size, dtype, *dist_params = rv_node.inputs
name = rv_var.name
Expand Down
28 changes: 27 additions & 1 deletion pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from aesara.graph.basic import Variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import SpecifyShape, specify_shape

from pymc3.aesaraf import change_rv_size, pandas_to_array
from pymc3.distributions import _logcdf, _logp
Expand Down Expand Up @@ -253,6 +254,13 @@ def __new__(
rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)
n_implied = rv_out.ndim

# The `.dist()` can wrap automatically with a SpecifyShape Op which brings informative
# error messages earlier in model construction.
# Here, however, the underyling RV must be used - a new SpecifyShape Op can be added at the end.
assert_shape = None
if isinstance(rv_out.owner.op, SpecifyShape):
rv_out, assert_shape = rv_out.owner.inputs

# `dims` are only available with this API, because `.dist()` can be used
# without a modelcontext and dims are not tracked at the Aesara level.
if dims is not None:
Expand Down Expand Up @@ -292,7 +300,15 @@ def __new__(
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
rv_out.tag.test_value = testval

return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform)
rv_registered = model.register_rv(
rv_out, name, observed, total_size, dims=dims, transform=transform
)

# Wrapping in specify_shape now does not break transforms:
if assert_shape is not None:
rv_registered = specify_shape(rv_registered, assert_shape)

return rv_registered

@classmethod
def dist(
Expand All @@ -314,6 +330,9 @@ def dist(
Ellipsis (...) may be used in the last position of the tuple,
and automatically expand to the shape implied by RV inputs.
Without Ellipsis, a `SpecifyShape` Op is automatically applied,
constraining this model variable to exactly the specified shape.
size : int, tuple, Variable, optional
A scalar or tuple for replicating the RV in addition
to its implied shape/dimensionality.
Expand All @@ -330,6 +349,7 @@ def dist(
raise NotImplementedError("The use of a `.dist(dims=...)` API is not yet supported.")

shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
assert_shape = None

# Create the RV without specifying size or testval.
# The size will be expanded later (if necessary) and only then the testval fits.
Expand All @@ -338,13 +358,16 @@ def dist(
if shape is None and size is None:
size = ()
elif shape is not None:
# SpecifyShape is automatically applied for symbolic and non-Ellipsis shapes
if isinstance(shape, Variable):
assert_shape = shape
size = ()
else:
if Ellipsis in shape:
size = tuple(shape[:-1])
else:
size = tuple(shape[: len(shape) - rv_native.ndim])
assert_shape = shape
# no-op conditions:
# `elif size is not None` (User already specified how to expand the RV)
# `else` (Unreachable)
Expand All @@ -354,6 +377,9 @@ def dist(
else:
rv_out = rv_native

if assert_shape is not None:
rv_out = specify_shape(rv_out, shape=assert_shape)

if testval is not None:
rv_out.tag.test_value = testval

Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_logp.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_logpt_incsubtensor(indices, shape):
sigma = 0.001
rng = aesara.shared(np.random.RandomState(232), borrow=True)

a = Normal.dist(mu, sigma, shape=shape, rng=rng)
a = Normal.dist(mu, sigma, rng=rng)
a.name = "a"

a_idx = at.set_subtensor(a[indices], data)
Expand Down
33 changes: 33 additions & 0 deletions pymc3/tests/test_shape_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,39 @@ def test_dist_api_works(self):
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
assert pm.Normal.dist(mu=mu, size=(4,)).eval().shape == (4, 3)

def test_auto_assert_shape(self):
with pytest.raises(AssertionError, match="will never match"):
pm.Normal.dist(mu=[1, 2], shape=[])

mu = at.vector(name="mu_input")
rv = pm.Normal.dist(mu=mu, shape=[3, 4])
f = aesara.function([mu], rv, mode=aesara.Mode("py"))
assert f([1, 2, 3, 4]).shape == (3, 4)

with pytest.raises(AssertionError, match=r"Got shape \(3, 2\), expected \(3, 4\)."):
f([1, 2])

# The `shape` can be symbolic!
s = at.vector(dtype="int32")
rv = pm.Uniform.dist(2, [4, 5], shape=s)
f = aesara.function([s], rv, mode=aesara.Mode("py"))
f(
[
2,
]
)
with pytest.raises(
AssertionError,
match=r"Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(3, 4\).",
):
f([3, 4])
with pytest.raises(
AssertionError,
match=r"Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\).",
):
f([])
pass

def test_lazy_flavors(self):

_validate_shape_dims_size(shape=5)
Expand Down

0 comments on commit c99f15c

Please sign in to comment.