diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index cf000f4c79..6fd727358b 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -75,6 +75,11 @@ class Ordered(RVTransform): name = "ordered" def __init__(self, ndim_supp=0): + if ndim_supp > 1: + raise ValueError( + f"For Ordered transformation number of core dimensions" + f"(ndim_supp) must not exceed 1 but is {ndim_supp}" + ) self.ndim_supp = ndim_supp def backward(self, value, *inputs): @@ -105,6 +110,11 @@ class SumTo1(RVTransform): name = "sumto1" def __init__(self, ndim_supp=0): + if ndim_supp > 1: + raise ValueError( + f"For SumTo1 transformation number of core dimensions" + f"(ndim_supp) must not exceed 1 but is {ndim_supp}" + ) self.ndim_supp = ndim_supp def backward(self, value, *inputs): @@ -352,11 +362,6 @@ def extend_axis_rev(array, axis): Instantiation of :class:`pymc.distributions.transforms.Ordered` for use in the ``transform`` argument of a multivariate random variable.""" -ordered = Ordered(ndim_supp=1) -ordered.__doc__ = """ -Instantiation of :class:`pymc.distributions.transforms.Ordered` -for use in the ``transform`` argument. """ - log = LogTransform() log.__doc__ = """ @@ -373,11 +378,6 @@ def extend_axis_rev(array, axis): Instantiation of :class:`pymc.distributions.transforms.SumTo1` for use in the ``transform`` argument of a multivariate random variable.""" -sum_to_1 = SumTo1(ndim_supp=1) -sum_to_1.__doc__ = """ -Instantiation of :class:`pymc.distributions.transforms.SumTo1` -for use in the ``transform`` argument of a random variable.""" - circular = CircularTransform() circular.__doc__ = """ Instantiation of :class:`aeppl.transforms.CircularTransform` diff --git a/pymc/tests/distributions/test_transform.py b/pymc/tests/distributions/test_transform.py index 54c9e695d5..59db3f398b 100644 --- a/pymc/tests/distributions/test_transform.py +++ b/pymc/tests/distributions/test_transform.py @@ -13,6 +13,8 @@ # limitations under the License. +from typing import Union + import aesara import aesara.tensor as at import numpy as np @@ -139,10 +141,12 @@ def test_simplex_accuracy(): def test_sum_to_1(): - check_vector_transform(tr.sum_to_1, Simplex(2)) - check_vector_transform(tr.sum_to_1, Simplex(4)) + check_vector_transform(tr.univariate_sum_to_1, Simplex(2)) + check_vector_transform(tr.univariate_sum_to_1, Simplex(4)) - check_jacobian_det(tr.sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]) + check_jacobian_det( + tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1] + ) def test_log(): @@ -241,28 +245,30 @@ def test_circular(): def test_ordered(): - check_vector_transform(tr.ordered, SortedVector(6)) + check_vector_transform(tr.univariate_ordered, SortedVector(6)) - check_jacobian_det(tr.ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False) + check_jacobian_det( + tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False + ) - vals = get_values(tr.ordered, Vector(R, 3), at.dvector, np.zeros(3)) + vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3)) close_to_logical(np.diff(vals) >= 0, True, tol) def test_chain_values(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) vals = get_values(chain_tranf, Vector(R, 5), at.dvector, np.zeros(5)) close_to_logical(np.diff(vals) >= 0, True, tol) def test_chain_vector_transform(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) check_vector_transform(chain_tranf, UnitSortedVector(3)) @pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.") def test_chain_jacob_det(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) check_jacobian_det(chain_tranf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False) @@ -327,7 +333,14 @@ def check_vectortransform_elementwise_logp(self, model): jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) # Original distribution is univariate if x.owner.op.ndim_supp == 0: - assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) + tr_steps = getattr(transform, "transform_list", [transform]) + transform_keeps_dim = any( + [isinstance(ts, Union[tr.SumTo1, tr.Ordered]) for ts in tr_steps] + ) + if transform_keeps_dim: + assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim + else: + assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) # Original distribution is multivariate else: assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim @@ -449,7 +462,7 @@ def test_normal_ordered(self): {"mu": 0.0, "sigma": 1.0}, size=3, initval=np.asarray([-1.0, 1.0, 4.0]), - transform=tr.ordered, + transform=tr.univariate_ordered, ) self.check_vectortransform_elementwise_logp(model) @@ -467,7 +480,7 @@ def test_half_normal_ordered(self, sigma, size): {"sigma": sigma}, size=size, initval=initval, - transform=tr.Chain([tr.log, tr.ordered]), + transform=tr.Chain([tr.log, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -479,7 +492,7 @@ def test_exponential_ordered(self, lam, size): {"lam": lam}, size=size, initval=initval, - transform=tr.Chain([tr.log, tr.ordered]), + transform=tr.Chain([tr.log, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -501,7 +514,7 @@ def test_beta_ordered(self, a, b, size): {"alpha": a, "beta": b}, size=size, initval=initval, - transform=tr.Chain([tr.logodds, tr.ordered]), + transform=tr.Chain([tr.logodds, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -524,7 +537,7 @@ def transform_params(*inputs): {"lower": lower, "upper": upper}, size=size, initval=initval, - transform=tr.Chain([interval, tr.ordered]), + transform=tr.Chain([interval, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -536,7 +549,7 @@ def test_vonmises_ordered(self, mu, kappa, size): {"mu": mu, "kappa": kappa}, size=size, initval=initval, - transform=tr.Chain([tr.circular, tr.ordered]), + transform=tr.Chain([tr.circular, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -545,7 +558,7 @@ def test_vonmises_ordered(self, mu, kappa, size): [ (0.0, 1.0, (2,), tr.simplex), (0.5, 5.5, (2, 3), tr.simplex), - (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])), + (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])), ], ) def test_uniform_other(self, lower, upper, size, transform): @@ -573,7 +586,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape): {"mu": mu, "cov": cov}, size=size, initval=initval, - transform=tr.ordered, + transform=tr.multivariate_ordered, ) self.check_vectortransform_elementwise_logp(model)