From d0f12eb4e7b7224e229ccfb5985c1a1caadea7bd Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 8 Jan 2022 17:57:32 +0100 Subject: [PATCH] Restore support for passing `dims` alongside `shape` or `size` Closes #4656 --- pymc/distributions/distribution.py | 18 ++++------ pymc/tests/test_shape_handling.py | 58 +++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 348d729c155..1f720c00c7a 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -212,16 +212,6 @@ def __new__( if rng is None: rng = model.next_rng() - if dims is not None and "shape" in kwargs: - raise ValueError( - f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!" - ) - if dims is not None and "size" in kwargs: - raise ValueError( - f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!" - ) - dims = convert_dims(dims) - # Create the RV without dims information, because that's not something tracked at the Aesara level. # If necessary we'll later replicate to a different size implied by already known dims. rv_out = cls.dist(*args, rng=rng, **kwargs) @@ -230,8 +220,14 @@ def __new__( # `dims` are only available with this API, because `.dist()` can be used # without a modelcontext and dims are not tracked at the Aesara level. + dims = convert_dims(dims) + dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None if dims is not None: - resize_shape, dims = resize_from_dims(dims, ndim_actual, model) + if dims_can_resize: + resize_shape, dims = resize_from_dims(dims, ndim_actual, model) + elif Ellipsis in dims: + # Replace ... with None entries to match the actual dimensionality. 
+                dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
 
         elif observed is not None:
             resize_shape, observed = resize_from_observed(observed, ndim_actual)
 
diff --git a/pymc/tests/test_shape_handling.py b/pymc/tests/test_shape_handling.py
index 45ef7c594ed..40e907a6737 100644
--- a/pymc/tests/test_shape_handling.py
+++ b/pymc/tests/test_shape_handling.py
@@ -293,6 +293,47 @@ def test_param_and_batch_shape_combos(
         else:
             raise NotImplementedError("Invalid test case parametrization.")
 
+    @pytest.mark.parametrize("ellipsis_in", ["none", "shape", "dims", "both"])
+    def test_simultaneous_shape_and_dims(self, ellipsis_in):
+        with pm.Model() as pmodel:
+            x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
+
+            if ellipsis_in == "none":
+                # The shape and dims tuples correspond to each other.
+                # Note: No checks are performed that implied shape (x), shape and dims actually match.
+                y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", "ddata"))
+                assert pmodel.RV_dims["y"] == ("dshape", "ddata")
+            elif ellipsis_in == "shape":
+                y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", "ddata"))
+                assert pmodel.RV_dims["y"] == ("dshape", "ddata")
+            elif ellipsis_in == "dims":
+                y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", ...))
+                assert pmodel.RV_dims["y"] == ("dshape", None)
+            elif ellipsis_in == "both":
+                y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", ...))
+                assert pmodel.RV_dims["y"] == ("dshape", None)
+
+        assert "dshape" in pmodel.dim_lengths
+        assert y.eval().shape == (2, 3)
+
+    @pytest.mark.parametrize("with_dims_ellipsis", [False, True])
+    def test_simultaneous_size_and_dims(self, with_dims_ellipsis):
+        with pm.Model() as pmodel:
+            x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
+            assert "ddata" in pmodel.dim_lengths
+
+            # Size does not include support dims, so this test must use a dist with support dims.
+ kwargs = dict(name="y", size=2, mu=at.ones((3, 4)), cov=at.eye(4)) + if with_dims_ellipsis: + y = pm.MvNormal(**kwargs, dims=("dsize", ...)) + assert pmodel.RV_dims["y"] == ("dsize", None, None) + else: + y = pm.MvNormal(**kwargs, dims=("dsize", "ddata", "dsupport")) + assert pmodel.RV_dims["y"] == ("dsize", "ddata", "dsupport") + + assert "dsize" in pmodel.dim_lengths + assert y.eval().shape == (2, 3, 4) + def test_define_dims_on_the_fly(self): with pm.Model() as pmodel: agedata = aesara.shared(np.array([10, 20, 30])) @@ -312,17 +353,6 @@ def test_define_dims_on_the_fly(self): # The change should propagate all the way through assert effect.eval().shape == (4,) - @pytest.mark.xfail(reason="Simultaneous use of size and dims is not implemented") - def test_data_defined_size_dimension_can_register_dimname(self): - with pm.Model() as pmodel: - x = pm.ConstantData("x", [[1, 2, 3, 4]], dims=("first", "second")) - assert "first" in pmodel.dim_lengths - assert "second" in pmodel.dim_lengths - # two dimensions are implied; a "third" dimension is created - y = pm.Normal("y", mu=x, size=2, dims=("third", "first", "second")) - assert "third" in pmodel.dim_lengths - assert y.eval().shape() == (2, 1, 4) - def test_can_resize_data_defined_size(self): with pm.Model() as pmodel: x = pm.MutableData("x", [[1, 2, 3, 4]], dims=("first", "second")) @@ -447,9 +477,3 @@ def test_lazy_flavors(self): def test_invalid_flavors(self): with pytest.raises(ValueError, match="Passing both"): pm.Normal.dist(0, 1, shape=(3,), size=(3,)) - - with pm.Model(): - with pytest.raises(ValueError, match="Passing both"): - pm.Normal("n", shape=(2,), dims=("town",)) - with pytest.raises(ValueError, match="Passing both"): - pm.Normal("n", dims=("town",), size=(2,))