Skip to content

Commit

Permalink
Restore support for passing dims alongside shape or size
Browse files Browse the repository at this point in the history
Closes #4656
  • Loading branch information
michaelosthege committed Jan 8, 2022
1 parent ca8f654 commit d0f12eb
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 28 deletions.
18 changes: 7 additions & 11 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,6 @@ def __new__(
if rng is None:
rng = model.next_rng()

if dims is not None and "shape" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
)
if dims is not None and "size" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
)
dims = convert_dims(dims)

# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, rng=rng, **kwargs)
Expand All @@ -230,8 +220,14 @@ def __new__(

# `dims` are only available with this API, because `.dist()` can be used
# without a modelcontext and dims are not tracked at the Aesara level.
dims = convert_dims(dims)
dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
if dims is not None:
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
if dims_can_resize:
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif Ellipsis in dims:
# Replace ... with None entries to match the actual dimensionality.
dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
elif observed is not None:
resize_shape, observed = resize_from_observed(observed, ndim_actual)

Expand Down
58 changes: 41 additions & 17 deletions pymc/tests/test_shape_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,47 @@ def test_param_and_batch_shape_combos(
else:
raise NotImplementedError("Invalid test case parametrization.")

@pytest.mark.parametrize("ellipsis_in", ["none", "shape", "dims", "both"])
def test_simultaneous_shape_and_dims(self, ellipsis_in):
with pm.Model() as pmodel:
x = pm.ConstantData("x", [1, 2, 3], dims="ddata")

if ellipsis_in == "none":
# The shape and dims tuples correspond to each other.
# Note: No checks are performed that implied shape (x), shape and dims actually match.
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", "ddata"))
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
elif ellipsis_in == "shape":
y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", "ddata"))
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
elif ellipsis_in == "dims":
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", ...))
assert pmodel.RV_dims["y"] == ("dshape", None)
elif ellipsis_in == "both":
y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", ...))
assert pmodel.RV_dims["y"] == ("dshape", None)

assert "dshape" in pmodel.dim_lengths
assert y.eval().shape == (2, 3)

@pytest.mark.parametrize("with_dims_ellipsis", [False, True])
def test_simultaneous_size_and_dims(self, with_dims_ellipsis):
with pm.Model() as pmodel:
x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
assert "ddata" in pmodel.dim_lengths

# Size does not included support dims, so this teest must use a dist with support dims.
kwargs = dict(name="y", size=2, mu=at.ones((3, 4)), cov=at.eye(4))
if with_dims_ellipsis:
y = pm.MvNormal(**kwargs, dims=("dsize", ...))
assert pmodel.RV_dims["y"] == ("dsize", None, None)
else:
y = pm.MvNormal(**kwargs, dims=("dsize", "ddata", "dsupport"))
assert pmodel.RV_dims["y"] == ("dsize", "ddata", "dsupport")

assert "dsize" in pmodel.dim_lengths
assert y.eval().shape == (2, 3, 4)

def test_define_dims_on_the_fly(self):
with pm.Model() as pmodel:
agedata = aesara.shared(np.array([10, 20, 30]))
Expand All @@ -312,17 +353,6 @@ def test_define_dims_on_the_fly(self):
# The change should propagate all the way through
assert effect.eval().shape == (4,)

@pytest.mark.xfail(reason="Simultaneous use of size and dims is not implemented")
def test_data_defined_size_dimension_can_register_dimname(self):
with pm.Model() as pmodel:
x = pm.ConstantData("x", [[1, 2, 3, 4]], dims=("first", "second"))
assert "first" in pmodel.dim_lengths
assert "second" in pmodel.dim_lengths
# two dimensions are implied; a "third" dimension is created
y = pm.Normal("y", mu=x, size=2, dims=("third", "first", "second"))
assert "third" in pmodel.dim_lengths
assert y.eval().shape() == (2, 1, 4)

def test_can_resize_data_defined_size(self):
with pm.Model() as pmodel:
x = pm.MutableData("x", [[1, 2, 3, 4]], dims=("first", "second"))
Expand Down Expand Up @@ -447,9 +477,3 @@ def test_lazy_flavors(self):
def test_invalid_flavors(self):
with pytest.raises(ValueError, match="Passing both"):
pm.Normal.dist(0, 1, shape=(3,), size=(3,))

with pm.Model():
with pytest.raises(ValueError, match="Passing both"):
pm.Normal("n", shape=(2,), dims=("town",))
with pytest.raises(ValueError, match="Passing both"):
pm.Normal("n", dims=("town",), size=(2,))

0 comments on commit d0f12eb

Please sign in to comment.