From 311bf9d1bff321ed0c8ca2f73701bb56a22bc7be Mon Sep 17 00:00:00 2001 From: "Tim Maier (Ubuntu Desktop)" Date: Mon, 31 Oct 2022 12:24:41 +0100 Subject: [PATCH] keep transforms.ordered for backward compatibility --- pymc/distributions/transforms.py | 18 ++++++++++++++++-- pymc/tests/distributions/test_transform.py | 10 +++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 6a308469b4d..cf000f4c79b 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -352,14 +352,28 @@ def extend_axis_rev(array, axis): Instantiation of :class:`pymc.distributions.transforms.Ordered` for use in the ``transform`` argument of a multivariate random variable.""" -ordered = univariate_ordered +ordered = Ordered(ndim_supp=1) +ordered.__doc__ = """ +Instantiation of :class:`pymc.distributions.transforms.Ordered` +for use in the ``transform`` argument. """ + log = LogTransform() log.__doc__ = """ Instantiation of :class:`aeppl.transforms.LogTransform` for use in the ``transform`` argument of a random variable.""" -sum_to_1 = SumTo1() +univariate_sum_to_1 = SumTo1(ndim_supp=0) +univariate_sum_to_1.__doc__ = """ +Instantiation of :class:`pymc.distributions.transforms.SumTo1` +for use in the ``transform`` argument of a univariate random variable.""" + +multivariate_sum_to_1 = SumTo1(ndim_supp=1) +multivariate_sum_to_1.__doc__ = """ +Instantiation of :class:`pymc.distributions.transforms.SumTo1` +for use in the ``transform`` argument of a multivariate random variable.""" + +sum_to_1 = SumTo1(ndim_supp=1) sum_to_1.__doc__ = """ Instantiation of :class:`pymc.distributions.transforms.SumTo1` for use in the ``transform`` argument of a random variable.""" diff --git a/pymc/tests/distributions/test_transform.py b/pymc/tests/distributions/test_transform.py index f26b8562f05..bffadbca260 100644 --- a/pymc/tests/distributions/test_transform.py +++ b/pymc/tests/distributions/test_transform.py @@ -310,7 +310,7 @@ def check_vectortransform_elementwise_logp(self, model): jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) # Original distribution is univariate if x.owner.op.ndim_supp == 0: - assert joint_logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim + assert joint_logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) # Original distribution is multivariate else: assert joint_logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim @@ -539,7 +539,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape): {"mu": mu, "cov": cov}, size=size, initval=initval, - transform=tr.multivariate_ordered, + transform=tr.ordered, ) self.check_vectortransform_elementwise_logp(model) @@ -573,11 +573,11 @@ def test_discrete_trafo(): def test_transforms_ordered(): with pm.Model() as model: pm.Normal( - "x", + "x_univariate", mu=[-3, -1, 1, 2], sigma=1, size=(10, 4), - transform=pm.distributions.transforms.ordered, + transform=tr.univariate_ordered, ) log_prob = model.point_logps() @@ -591,7 +591,7 @@ def test_transforms_sumto1(): mu=[-3, -1, 1, 2], sigma=1, size=(10, 4), - transform=pm.distributions.transforms.sum_to_1, + transform=tr.univariate_sum_to_1, ) log_prob = model.point_logps()