Skip to content

Commit

Permalink
keep transforms.ordered for backward compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
TimOliverMaier committed Oct 31, 2022
1 parent 27bb86e commit 311bf9d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
18 changes: 16 additions & 2 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,14 +352,28 @@ def extend_axis_rev(array, axis):
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a multivariate random variable."""

ordered = univariate_ordered
ordered = Ordered(ndim_supp=1)
ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument. """


log = LogTransform()
log.__doc__ = """
Instantiation of :class:`aeppl.transforms.LogTransform`
for use in the ``transform`` argument of a random variable."""

sum_to_1 = SumTo1()
univariate_sum_to_1 = SumTo1(ndim_supp=0)
univariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a univariate random variable."""

multivariate_sum_to_1 = SumTo1(ndim_supp=1)
multivariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a multivariate random variable."""

sum_to_1 = SumTo1(ndim_supp=1)
sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a random variable."""
Expand Down
10 changes: 5 additions & 5 deletions pymc/tests/distributions/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def check_vectortransform_elementwise_logp(self, model):
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
# Original distribution is univariate
if x.owner.op.ndim_supp == 0:
assert joint_logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
assert joint_logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
# Original distribution is multivariate
else:
assert joint_logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
Expand Down Expand Up @@ -539,7 +539,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape):
{"mu": mu, "cov": cov},
size=size,
initval=initval,
transform=tr.multivariate_ordered,
transform=tr.ordered,
)
self.check_vectortransform_elementwise_logp(model)

Expand Down Expand Up @@ -573,11 +573,11 @@ def test_discrete_trafo():
def test_transforms_ordered():
with pm.Model() as model:
pm.Normal(
"x",
"x_univariate",
mu=[-3, -1, 1, 2],
sigma=1,
size=(10, 4),
transform=pm.distributions.transforms.ordered,
transform=tr.univariate_ordered,
)

log_prob = model.point_logps()
Expand All @@ -591,7 +591,7 @@ def test_transforms_sumto1():
mu=[-3, -1, 1, 2],
sigma=1,
size=(10, 4),
transform=pm.distributions.transforms.sum_to_1,
transform=tr.univariate_sum_to_1,
)

log_prob = model.point_logps()
Expand Down

0 comments on commit 311bf9d

Please sign in to comment.