Skip to content

Commit

Permalink
pre-commit run
Browse files Browse the repository at this point in the history
  • Loading branch information
TimOliverMaier committed Nov 19, 2022
1 parent 9d279db commit 7515c73
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 28 deletions.
20 changes: 10 additions & 10 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ class Ordered(RVTransform):
name = "ordered"

def __init__(self, ndim_supp=0):
if ndim_supp > 1:
raise ValueError(
f"For Ordered transformation number of core dimensions"
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
)
self.ndim_supp = ndim_supp

def backward(self, value, *inputs):
Expand Down Expand Up @@ -105,6 +110,11 @@ class SumTo1(RVTransform):
name = "sumto1"

def __init__(self, ndim_supp=0):
if ndim_supp > 1:
raise ValueError(
f"For SumTo1 transformation number of core dimensions"
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
)
self.ndim_supp = ndim_supp

def backward(self, value, *inputs):
Expand Down Expand Up @@ -352,11 +362,6 @@ def extend_axis_rev(array, axis):
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a multivariate random variable."""

ordered = Ordered(ndim_supp=1)
ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument. """


log = LogTransform()
log.__doc__ = """
Expand All @@ -373,11 +378,6 @@ def extend_axis_rev(array, axis):
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a multivariate random variable."""

sum_to_1 = SumTo1(ndim_supp=1)
sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a random variable."""

circular = CircularTransform()
circular.__doc__ = """
Instantiation of :class:`aeppl.transforms.CircularTransform`
Expand Down
49 changes: 31 additions & 18 deletions pymc/tests/distributions/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from typing import Union

import aesara
import aesara.tensor as at
import numpy as np
Expand Down Expand Up @@ -139,10 +141,12 @@ def test_simplex_accuracy():


def test_sum_to_1():
check_vector_transform(tr.sum_to_1, Simplex(2))
check_vector_transform(tr.sum_to_1, Simplex(4))
check_vector_transform(tr.univariate_sum_to_1, Simplex(2))
check_vector_transform(tr.univariate_sum_to_1, Simplex(4))

check_jacobian_det(tr.sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1])
check_jacobian_det(
tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]
)


def test_log():
Expand Down Expand Up @@ -241,28 +245,30 @@ def test_circular():


def test_ordered():
check_vector_transform(tr.ordered, SortedVector(6))
check_vector_transform(tr.univariate_ordered, SortedVector(6))

check_jacobian_det(tr.ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False)
check_jacobian_det(
tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False
)

vals = get_values(tr.ordered, Vector(R, 3), at.dvector, np.zeros(3))
vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3))
close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_values():
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
vals = get_values(chain_tranf, Vector(R, 5), at.dvector, np.zeros(5))
close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_vector_transform():
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
check_vector_transform(chain_tranf, UnitSortedVector(3))


@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
def test_chain_jacob_det():
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
check_jacobian_det(chain_tranf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False)


Expand Down Expand Up @@ -327,7 +333,14 @@ def check_vectortransform_elementwise_logp(self, model):
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
# Original distribution is univariate
if x.owner.op.ndim_supp == 0:
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
tr_steps = getattr(transform, "transform_list", [transform])
transform_keeps_dim = any(
[isinstance(ts, Union[tr.SumTo1, tr.Ordered]) for ts in tr_steps]
)
if transform_keeps_dim:
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
else:
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
# Original distribution is multivariate
else:
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
Expand Down Expand Up @@ -449,7 +462,7 @@ def test_normal_ordered(self):
{"mu": 0.0, "sigma": 1.0},
size=3,
initval=np.asarray([-1.0, 1.0, 4.0]),
transform=tr.ordered,
transform=tr.univariate_ordered,
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -467,7 +480,7 @@ def test_half_normal_ordered(self, sigma, size):
{"sigma": sigma},
size=size,
initval=initval,
transform=tr.Chain([tr.log, tr.ordered]),
transform=tr.Chain([tr.log, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -479,7 +492,7 @@ def test_exponential_ordered(self, lam, size):
{"lam": lam},
size=size,
initval=initval,
transform=tr.Chain([tr.log, tr.ordered]),
transform=tr.Chain([tr.log, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -501,7 +514,7 @@ def test_beta_ordered(self, a, b, size):
{"alpha": a, "beta": b},
size=size,
initval=initval,
transform=tr.Chain([tr.logodds, tr.ordered]),
transform=tr.Chain([tr.logodds, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -524,7 +537,7 @@ def transform_params(*inputs):
{"lower": lower, "upper": upper},
size=size,
initval=initval,
transform=tr.Chain([interval, tr.ordered]),
transform=tr.Chain([interval, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -536,7 +549,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
{"mu": mu, "kappa": kappa},
size=size,
initval=initval,
transform=tr.Chain([tr.circular, tr.ordered]),
transform=tr.Chain([tr.circular, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -545,7 +558,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
[
(0.0, 1.0, (2,), tr.simplex),
(0.5, 5.5, (2, 3), tr.simplex),
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
],
)
def test_uniform_other(self, lower, upper, size, transform):
Expand Down Expand Up @@ -573,7 +586,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape):
{"mu": mu, "cov": cov},
size=size,
initval=initval,
transform=tr.ordered,
transform=tr.multivariate_ordered,
)
self.check_vectortransform_elementwise_logp(model)

Expand Down

0 comments on commit 7515c73

Please sign in to comment.