Commit 512491a

Merge branch 'fix-4662' into reintro_shape

twiecki committed May 14, 2021
2 parents 08754cd + 3f8ea24

Showing 4 changed files with 216 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pymc3/aesaraf.py
@@ -49,6 +49,7 @@
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
from aesara.tensor.var import TensorVariable

from pymc3.exceptions import ShapeError
from pymc3.vartypes import continuous_types, int_types, isgenerator, typefilter

PotentialShapeType = Union[
@@ -146,6 +147,12 @@ def change_rv_size(
        Expand the existing size by `new_size`.
    """
    new_size_ndim = new_size.ndim if isinstance(new_size, Variable) else np.ndim(new_size)
    if new_size_ndim > 1:
        raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
    new_size = at.as_tensor_variable(new_size, ndim=1)
    if isinstance(rv_var.owner.op, SpecifyShape):
        rv_var = rv_var.owner.inputs[0]
    rv_node = rv_var.owner
    rng, size, dtype, *dist_params = rv_node.inputs
    name = rv_var.name
@@ -154,7 +161,7 @@ def change_rv_size(
    if expand:
        if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
            size = rv_node.op._infer_shape(size, dist_params)
-        new_size = tuple(np.atleast_1d(new_size)) + tuple(size)
+        new_size = tuple(new_size) + tuple(size)

    new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
    rv_var = new_rv_node.outputs[-1]
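For orientation, here is a minimal usage sketch of the revised `change_rv_size` (not part of the diff; it assumes the pymc3 v4 development API that the tests in this commit exercise):

import numpy as np
import pymc3 as pm
from pymc3.aesaraf import change_rv_size

# A base RV whose shape (2,) is implied by its mean parameter.
rv = pm.Normal.dist(mu=np.zeros(2))
# Replacing the size outright: draws now have shape (5, 2).
assert change_rv_size(rv, new_size=(5, 2)).eval().shape == (5, 2)
# `expand=True` prepends `new_size` to the existing size instead: shape (3, 2).
assert change_rv_size(rv, new_size=(3,), expand=True).eval().shape == (3, 2)
# A >1-dimensional `new_size` now raises ShapeError (see the test below).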
6 changes: 6 additions & 0 deletions pymc3/tests/test_aesaraf.py
@@ -41,6 +41,7 @@
    take_along_axis,
    walk_model,
)
from pymc3.exceptions import ShapeError
from pymc3.vartypes import int_types

FLOATX = str(aesara.config.floatX)
@@ -53,6 +54,11 @@ def test_change_rv_size():
    assert rv.ndim == 1
    assert rv.eval().shape == (2,)

    with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
        change_rv_size(rv, new_size=[[2, 3]])
    with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
        change_rv_size(rv, new_size=at.as_tensor_variable([[2, 3], [4, 5]]))

    rv_new = change_rv_size(rv, new_size=(3,), expand=True)
    assert rv_new.ndim == 2
    assert rv_new.eval().shape == (3, 2)
5 changes: 5 additions & 0 deletions pymc3/tests/test_ode.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import aesara
import numpy as np
import pytest
@@ -168,6 +170,9 @@ def ode_func_5(y, t, p):
    np.testing.assert_array_equal(np.ravel(model5_sens_ic), model5._sens_ic)


@pytest.mark.xfail(
    condition=sys.platform == "win32", reason="See https://github.com/pymc-devs/pymc3/issues/4652."
)
def test_logp_scalar_ode():
    """Test the computation of the log probability for these models"""

197 changes: 197 additions & 0 deletions pymc3/tests/test_shape_handling.py
@@ -219,3 +219,200 @@ def test_sample_generate_values(fixture_model, fixture_sizes):
    prior = pm.sample_prior_predictive(samples=fixture_sizes)
    for rv in RVs:
        assert prior[rv.name].shape == size + tuple(rv.distribution.shape)


class TestShapeDimsSize:
    @pytest.mark.parametrize("param_shape", [(), (3,)])
    @pytest.mark.parametrize("batch_shape", [(), (3,)])
    @pytest.mark.parametrize(
        "parametrization",
        [
            "implicit",
            "shape",
            "shape...",
            "dims",
            "dims...",
            "size",
        ],
    )
    def test_param_and_batch_shape_combos(
        self, param_shape: tuple, batch_shape: tuple, parametrization: str
    ):
        coords = {}
        param_dims = []
        batch_dims = []

        # Create coordinates corresponding to the parameter shape
        for d in param_shape:
            dname = f"param_dim_{d}"
            coords[dname] = [f"c_{i}" for i in range(d)]
            param_dims.append(dname)
        assert len(param_dims) == len(param_shape)
        # Create coordinates corresponding to the batch shape
        for d in batch_shape:
            dname = f"batch_dim_{d}"
            coords[dname] = [f"c_{i}" for i in range(d)]
            batch_dims.append(dname)
        assert len(batch_dims) == len(batch_shape)

        with pm.Model(coords=coords) as pmodel:
            mu = aesara.shared(np.random.normal(size=param_shape))

            with pytest.warns(None):
                if parametrization == "implicit":
                    rv = pm.Normal("rv", mu=mu)
                    assert rv.eval().shape == param_shape
                else:
                    if parametrization == "shape":
                        rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape)
                        assert rv.eval().shape == batch_shape + param_shape
                    elif parametrization == "shape...":
                        rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
                        assert rv.eval().shape == batch_shape + param_shape
                    elif parametrization == "dims":
                        rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims)
                        assert rv.eval().shape == batch_shape + param_shape
                    elif parametrization == "dims...":
                        rv = pm.Normal("rv", mu=mu, dims=(*batch_dims, ...))
                        n_size = len(batch_shape)
                        n_implied = len(param_shape)
                        ndim = n_size + n_implied
                        assert len(pmodel.RV_dims["rv"]) == ndim, pmodel.RV_dims
                        assert len(pmodel.RV_dims["rv"][:n_size]) == len(batch_dims)
                        assert len(pmodel.RV_dims["rv"][n_size:]) == len(param_dims)
                        if n_implied > 0:
                            assert pmodel.RV_dims["rv"][-1] is None
                    elif parametrization == "size":
                        rv = pm.Normal("rv", mu=mu, size=batch_shape)
                        assert rv.eval().shape == batch_shape + param_shape
                    else:
                        raise NotImplementedError("Invalid test case parametrization.")
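The five parametrizations above are interchangeable ways of pinning down the same draw shape. A condensed sketch of that equivalence (the model, variable names, and the "batch" coordinate are made up for illustration; it assumes the v4 dev API under test):

import numpy as np
import pymc3 as pm

mu = np.zeros(3)  # implies a (3,)-shaped RV
with pm.Model(coords={"batch": ["a", "b"]}):
    x1 = pm.Normal("x1", mu=mu, shape=(2, 3))         # full, explicit shape
    x2 = pm.Normal("x2", mu=mu, shape=(2, ...))       # Ellipsis keeps the implied dims
    x3 = pm.Normal("x3", mu=mu, size=(2,))            # size holds only the batch dims
    x4 = pm.Normal("x4", mu=mu, dims=("batch", ...))  # named batch dim + implied rest
    for x in (x1, x2, x3, x4):
        assert x.eval().shape == (2, 3)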

    def test_define_dims_on_the_fly(self):
        with pm.Model() as pmodel:
            agedata = aesara.shared(np.array([10, 20, 30]))

            # Associate the "patient" dim with an implied dimension
            age = pm.Normal("age", agedata, dims=("patient",))
            assert "patient" in pmodel.dim_lengths
            assert pmodel.dim_lengths["patient"].eval() == 3

            # Use the dim to replicate a new RV
            effect = pm.Normal("effect", 0, dims=("patient",))
            assert effect.ndim == 1
            assert effect.eval().shape == (3,)

            # Now change the length of the implied dimension
            agedata.set_value([1, 2, 3, 4])
            # The change should propagate all the way through
            assert effect.eval().shape == (4,)

    @pytest.mark.xfail(reason="Simultaneous use of size and dims is not implemented")
    def test_data_defined_size_dimension_can_register_dimname(self):
        with pm.Model() as pmodel:
            x = pm.Data("x", [[1, 2, 3, 4]], dims=("first", "second"))
            assert "first" in pmodel.dim_lengths
            assert "second" in pmodel.dim_lengths
            # two dimensions are implied; a "third" dimension is created
            y = pm.Normal("y", mu=x, size=2, dims=("third", "first", "second"))
            assert "third" in pmodel.dim_lengths
            assert y.eval().shape == (2, 1, 4)

    def test_can_resize_data_defined_size(self):
        with pm.Model() as pmodel:
            x = pm.Data("x", [[1, 2, 3, 4]], dims=("first", "second"))
            y = pm.Normal("y", mu=0, dims=("first", "second"))
            z = pm.Normal("z", mu=y, observed=np.ones((1, 4)))
            assert x.eval().shape == (1, 4)
            assert y.eval().shape == (1, 4)
            assert z.eval().shape == (1, 4)
            assert "first" in pmodel.dim_lengths
            assert "second" in pmodel.dim_lengths
            pmodel.set_data("x", [[1, 2], [3, 4], [5, 6]])
            assert x.eval().shape == (3, 2)
            assert y.eval().shape == (3, 2)
            assert z.eval().shape == (3, 2)

    @pytest.mark.xfail(
        condition=sys.platform == "win32",
        reason="See https://github.com/pymc-devs/pymc3/issues/4652.",
    )
    def test_observed_with_column_vector(self):
        with pm.Model() as model:
            pm.Normal("x1", mu=0, sd=1, observed=np.random.normal(size=(3, 4)))
            model.logp()
            pm.Normal("x2", mu=0, sd=1, observed=np.random.normal(size=(3, 1)))
            model.logp()

    def test_dist_api_works(self):
        mu = aesara.shared(np.array([1, 2, 3]))
        with pytest.raises(NotImplementedError, match="API is not yet supported"):
            pm.Normal.dist(mu=mu, dims=("town",))
        assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,)
        assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)
        assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
        assert pm.Normal.dist(mu=mu, size=(4,)).eval().shape == (4, 3)

    def test_auto_assert_shape(self):
        with pytest.raises(AssertionError, match="will never match"):
            pm.Normal.dist(mu=[1, 2], shape=[])

        mu = at.vector(name="mu_input")
        rv = pm.Normal.dist(mu=mu, shape=[3, 4])
        f = aesara.function([mu], rv, mode=aesara.Mode("py"))
        assert f([1, 2, 3, 4]).shape == (3, 4)

        with pytest.raises(AssertionError, match=r"Got shape \(3, 2\), expected \(3, 4\)."):
            f([1, 2])

        # The `shape` can be symbolic!
        s = at.vector(dtype="int32")
        rv = pm.Uniform.dist(2, [4, 5], shape=s)
        f = aesara.function([s], rv, mode=aesara.Mode("py"))
        f([2])
        with pytest.raises(
            AssertionError,
            match=r"Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(3, 4\).",
        ):
            f([3, 4])
        with pytest.raises(
            AssertionError,
            match=r"Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\).",
        ):
            f([])

    def test_lazy_flavors(self):
        _validate_shape_dims_size(shape=5)
        _validate_shape_dims_size(dims="town")
        _validate_shape_dims_size(size=7)

        assert pm.Uniform.dist(2, [4, 5], size=[3, 4]).eval().shape == (3, 4, 2)
        assert pm.Uniform.dist(2, [4, 5], shape=[3, 2]).eval().shape == (3, 2)
        with pm.Model(coords=dict(town=["Greifswald", "Madrid"])):
            assert pm.Normal("n2", mu=[1, 2], dims=("town",)).eval().shape == (2,)

    def test_invalid_flavors(self):
        # redundant parametrizations
        with pytest.raises(ValueError, match="Passing both"):
            _validate_shape_dims_size(shape=(2,), dims=("town",))
        with pytest.raises(ValueError, match="Passing both"):
            _validate_shape_dims_size(dims=("town",), size=(2,))
        with pytest.raises(ValueError, match="Passing both"):
            _validate_shape_dims_size(shape=(3,), size=(3,))

        # invalid, but not necessarily rare
        with pytest.raises(ValueError, match="must be an int, list or tuple"):
            _validate_shape_dims_size(size="notasize")

        # invalid ellipsis positions
        with pytest.raises(ValueError, match="may only appear in the last position"):
            _validate_shape_dims_size(shape=(3, ..., 2))
        with pytest.raises(ValueError, match="may only appear in the last position"):
            _validate_shape_dims_size(dims=(..., "town"))
        with pytest.raises(ValueError, match="cannot contain"):
            _validate_shape_dims_size(size=(3, ...))
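Taken together, these last two tests pin down the contract of the private `_validate_shape_dims_size` helper. A standalone restatement of those rules (a sketch only, not pymc3's actual implementation; the function name here is made up):

def validate_shape_dims_size_contract(shape=None, dims=None, size=None):
    # At most one of shape/dims/size may be given; combinations are redundant.
    if sum(arg is not None for arg in (shape, dims, size)) > 1:
        raise ValueError("Passing both of shape, dims and size is redundant.")
    if size is not None:
        if not isinstance(size, (int, list, tuple)):
            raise ValueError("The size must be an int, list or tuple.")
        if not isinstance(size, int) and Ellipsis in tuple(size):
            raise ValueError("The size cannot contain an Ellipsis.")
    # An Ellipsis is allowed in shape and dims, but only in the last position.
    for arg in (shape, dims):
        if arg is None or isinstance(arg, (int, str)):
            continue
        if Ellipsis in tuple(arg)[:-1]:
            raise ValueError("An Ellipsis may only appear in the last position.")

validate_shape_dims_size_contract(shape=(3, ...))  # OK: trailing Ellipsis
try:
    validate_shape_dims_size_contract(size=(3, ...))
except ValueError as e:
    print(e)  # "The size cannot contain an Ellipsis."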
