Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ordering Transformation for batched dimensions #6255

Merged
merged 13 commits into from
Nov 22, 2022
55 changes: 45 additions & 10 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@
"logodds",
"Interval",
"log_exp_m1",
"ordered",
"univariate_ordered",
"multivariate_ordered",
"log",
"sum_to_1",
"univariate_sum_to_1",
"multivariate_sum_to_1",
"circular",
"CholeskyCovPacked",
"Chain",
Expand Down Expand Up @@ -74,6 +76,14 @@ def log_jac_det(self, value, *inputs):
class Ordered(RVTransform):
name = "ordered"

def __init__(self, ndim_supp=0):
    """Configure the Ordered transform for scalar or vector support.

    Parameters
    ----------
    ndim_supp : int
        Number of core (support) dimensions of the transformed variable:
        0 for a univariate variable, 1 for a multivariate one.

    Raises
    ------
    ValueError
        If ``ndim_supp`` is greater than 1.
    """
    if ndim_supp > 1:
        # Fix: the two adjacent string literals previously concatenated without a
        # space, producing "...core dimensions(ndim_supp)..." in the message.
        raise ValueError(
            "For Ordered transformation number of core dimensions "
            f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
        )
    self.ndim_supp = ndim_supp

def backward(self, value, *inputs):
x = at.zeros(value.shape)
x = at.inc_subtensor(x[..., 0], value[..., 0])
Expand All @@ -87,7 +97,10 @@ def forward(self, value, *inputs):
return y

def log_jac_det(self, value, *inputs):
    """Log determinant of the Jacobian of the forward transform.

    For a univariate variable (``ndim_supp == 0``) the reduced axis is kept
    (``keepdims=True``) so the result broadcasts against the elemwise logp of
    batched dimensions; for a multivariate variable (``ndim_supp == 1``) the
    core dimension is summed away entirely.
    """
    # Fix: an unconditional pre-change `return at.sum(value[..., 1:], axis=-1)`
    # preceded this branch (diff residue), making the if/else dead code.
    if self.ndim_supp == 0:
        return at.sum(value[..., 1:], axis=-1, keepdims=True)
    else:
        return at.sum(value[..., 1:], axis=-1)


class SumTo1(RVTransform):
Expand All @@ -98,6 +111,14 @@ class SumTo1(RVTransform):

name = "sumto1"

def __init__(self, ndim_supp=0):
    """Configure the SumTo1 transform for scalar or vector support.

    Parameters
    ----------
    ndim_supp : int
        Number of core (support) dimensions of the transformed variable:
        0 for a univariate variable, 1 for a multivariate one.

    Raises
    ------
    ValueError
        If ``ndim_supp`` is greater than 1.
    """
    if ndim_supp > 1:
        # Fix: the two adjacent string literals previously concatenated without a
        # space, producing "...core dimensions(ndim_supp)..." in the message.
        raise ValueError(
            "For SumTo1 transformation number of core dimensions "
            f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
        )
    self.ndim_supp = ndim_supp

def backward(self, value, *inputs):
    """Recover the full vector by appending the residual mass.

    The unconstrained ``value`` holds all but the last entry; the last entry
    is reconstructed as ``1 - sum(value)`` so the result sums to one along
    the final axis.
    """
    total = at.sum(value[..., :], axis=-1, keepdims=True)
    last = 1 - total
    return at.concatenate([value[..., :], last], axis=-1)
Expand All @@ -107,7 +128,10 @@ def forward(self, value, *inputs):

def log_jac_det(self, value, *inputs):
    """Log determinant of the Jacobian of the forward transform.

    The SumTo1 forward map is a projection with unit Jacobian, so the log
    determinant is identically zero; only the output shape differs between
    the univariate (``keepdims=True``, broadcasts over batch dims) and
    multivariate (core dim reduced) cases.
    """
    # Fix: an unconditional pre-change `return at.sum(y, axis=-1)` preceded
    # this branch (diff residue), making the if/else dead code.
    y = at.zeros(value.shape)
    if self.ndim_supp == 0:
        return at.sum(y, axis=-1, keepdims=True)
    else:
        return at.sum(y, axis=-1)


class CholeskyCovPacked(RVTransform):
Expand Down Expand Up @@ -330,20 +354,31 @@ def extend_axis_rev(array, axis):
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
for use in the ``transform`` argument of a random variable."""

ordered = Ordered()
ordered.__doc__ = """
univariate_ordered = Ordered(ndim_supp=0)
univariate_ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a random variable."""
for use in the ``transform`` argument of a univariate random variable."""

multivariate_ordered = Ordered(ndim_supp=1)
multivariate_ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a multivariate random variable."""


log = LogTransform()
log.__doc__ = """
Instantiation of :class:`aeppl.transforms.LogTransform`
for use in the ``transform`` argument of a random variable."""

sum_to_1 = SumTo1()
sum_to_1.__doc__ = """
univariate_sum_to_1 = SumTo1(ndim_supp=0)
univariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a random variable."""
for use in the ``transform`` argument of a univariate random variable."""

multivariate_sum_to_1 = SumTo1(ndim_supp=1)
multivariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a multivariate random variable."""

circular = CircularTransform()
circular.__doc__ = """
Expand Down
81 changes: 63 additions & 18 deletions pymc/tests/distributions/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from typing import Union

import aesara
import aesara.tensor as at
import numpy as np
Expand Down Expand Up @@ -139,10 +141,12 @@ def test_simplex_accuracy():


def test_sum_to_1():
    """univariate_sum_to_1 round-trips on simplexes and has the expected Jacobian."""
    # Fix: stale duplicate calls referencing the removed name `tr.sum_to_1`
    # (pre-rename diff residue) would raise AttributeError; only the
    # `univariate_sum_to_1` versions are kept.
    check_vector_transform(tr.univariate_sum_to_1, Simplex(2))
    check_vector_transform(tr.univariate_sum_to_1, Simplex(4))

    check_jacobian_det(
        tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]
    )


def test_log():
Expand Down Expand Up @@ -241,28 +245,30 @@ def test_circular():


def test_ordered():
    """univariate_ordered round-trips, has the right Jacobian, and yields sorted values."""
    # Fix: stale duplicate calls referencing the removed name `tr.ordered`
    # (pre-rename diff residue) would raise AttributeError; only the
    # `univariate_ordered` versions are kept.
    check_vector_transform(tr.univariate_ordered, SortedVector(6))

    check_jacobian_det(
        tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False
    )

    vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3))
    close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_values():
    """Chained logodds + ordered transform produces monotonically non-decreasing values."""
    # Fix: drop the stale duplicate assignment using removed `tr.ordered`
    # (diff residue); also correct the `chain_tranf` local-name typo.
    chain_transf = tr.Chain([tr.logodds, tr.univariate_ordered])
    vals = get_values(chain_transf, Vector(R, 5), at.dvector, np.zeros(5))
    close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_vector_transform():
    """Chained logodds + ordered transform round-trips on a sorted unit vector."""
    # Fix: drop the stale duplicate assignment using removed `tr.ordered`
    # (diff residue); also correct the `chain_tranf` local-name typo.
    chain_transf = tr.Chain([tr.logodds, tr.univariate_ordered])
    check_vector_transform(chain_transf, UnitSortedVector(3))


@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
def test_chain_jacob_det():
    """Jacobian determinant of the chained transform (known float-precision xfail)."""
    # Fix: drop the stale duplicate assignment using removed `tr.ordered`
    # (diff residue); also correct the `chain_tranf` local-name typo.
    chain_transf = tr.Chain([tr.logodds, tr.univariate_ordered])
    check_jacobian_det(chain_transf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False)


Expand Down Expand Up @@ -327,7 +333,14 @@ def check_vectortransform_elementwise_logp(self, model):
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
# Original distribution is univariate
if x.owner.op.ndim_supp == 0:
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
tr_steps = getattr(transform, "transform_list", [transform])
transform_keeps_dim = any(
[isinstance(ts, Union[tr.SumTo1, tr.Ordered]) for ts in tr_steps]
)
if transform_keeps_dim:
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
else:
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
# Original distribution is multivariate
else:
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
Expand Down Expand Up @@ -449,7 +462,7 @@ def test_normal_ordered(self):
{"mu": 0.0, "sigma": 1.0},
size=3,
initval=np.asarray([-1.0, 1.0, 4.0]),
transform=tr.ordered,
transform=tr.univariate_ordered,
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -467,7 +480,7 @@ def test_half_normal_ordered(self, sigma, size):
{"sigma": sigma},
size=size,
initval=initval,
transform=tr.Chain([tr.log, tr.ordered]),
transform=tr.Chain([tr.log, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -479,7 +492,7 @@ def test_exponential_ordered(self, lam, size):
{"lam": lam},
size=size,
initval=initval,
transform=tr.Chain([tr.log, tr.ordered]),
transform=tr.Chain([tr.log, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -501,7 +514,7 @@ def test_beta_ordered(self, a, b, size):
{"alpha": a, "beta": b},
size=size,
initval=initval,
transform=tr.Chain([tr.logodds, tr.ordered]),
transform=tr.Chain([tr.logodds, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -524,7 +537,7 @@ def transform_params(*inputs):
{"lower": lower, "upper": upper},
size=size,
initval=initval,
transform=tr.Chain([interval, tr.ordered]),
transform=tr.Chain([interval, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -536,7 +549,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
{"mu": mu, "kappa": kappa},
size=size,
initval=initval,
transform=tr.Chain([tr.circular, tr.ordered]),
transform=tr.Chain([tr.circular, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

Expand All @@ -545,7 +558,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
[
(0.0, 1.0, (2,), tr.simplex),
(0.5, 5.5, (2, 3), tr.simplex),
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
],
)
def test_uniform_other(self, lower, upper, size, transform):
Expand All @@ -569,7 +582,11 @@ def test_uniform_other(self, lower, upper, size, transform):
def test_mvnormal_ordered(self, mu, cov, size, shape):
initval = np.sort(np.random.randn(*shape))
model = self.build_model(
pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered
pm.MvNormal,
{"mu": mu, "cov": cov},
size=size,
initval=initval,
transform=tr.multivariate_ordered,
)
self.check_vectortransform_elementwise_logp(model)

Expand Down Expand Up @@ -598,3 +615,31 @@ def test_discrete_trafo():
with pytest.raises(ValueError) as err:
pm.Binomial("a", n=5, p=0.5, transform="log")
err.match("Transformations for discrete distributions")


def test_transforms_ordered():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a more direct name?

Suggested change
def test_transforms_ordered():
def test_2d_univariate_ordered():

with pm.Model() as model:
pm.Normal(
"x_univariate",
mu=[-3, -1, 1, 2],
sigma=1,
size=(10, 4),
transform=tr.univariate_ordered,
)

log_prob = model.point_logps()
Copy link
Member

@ricardoV94 ricardoV94 Nov 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this test a bit more readable I would suggest including an equivalent x_1d = pm.Normal(..., shape=(4,)) in the model and comparing the elemwise logp of that is the same as each copy of the 2D (10, 4) one that exists now.

You can do that via m.compile_logp(sum=False)({"x_1d_ordered__": np.zeros((4,)), x_2d_ordered__": np.zeros(10, 4)}) and then asserting closeness across the last axis.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sum_to_1 transform the logps are not the same. Is this expected? Even if the shape is the same like here:

def test_2d_univariate_sum_to_1():
    with pm.Model() as model:
        x_1d = pm.Normal(
            "x_1d",
            mu=[-3,-1,1,2],
            sigma=1,
            size=(10,4),
            transform=tr.univariate_sum_to_1,
        )
        x_2d = pm.Normal(
            "x_2d",
            mu=[-3, -1, 1, 2],
            sigma=1,
            size=(10, 4),
            transform=tr.univariate_sum_to_1,
        )

    log_p = model.compile_logp(sum=False)({"x_1d_sumto1__":np.ones((10,3))*0.25,"x_2d_sumto1__":np.zeros((10,3))*0.25})
    np.testing.assert_allclose(log_p[0],log_p[1])

Fails with:

./pymc/tests/distributions/test_transform.py::test_2d_univariate_sum_to_1 Failed: [undefined]AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 40 / 40 (100%)
Max absolute difference: 1.03125
Max relative difference: 0.72677567
 x: array([[-6.200189, -1.700189, -1.200189, -2.450189],
       [-6.200189, -1.700189, -1.200189, -2.450189],
       [-6.200189, -1.700189, -1.200189, -2.450189],...
 y: array([[-5.418939, -1.418939, -1.418939, -1.418939],
       [-5.418939, -1.418939, -1.418939, -1.418939],
       [-5.418939, -1.418939, -1.418939, -1.418939],...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok funny. This is not the case if I use np.zeros

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR looks good otherwise, should we investigate the SumTo1?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this when comparing np.zeros(), vs np.ones()? I don't expect those to be equivalent...

Copy link
Contributor Author

@TimOliverMaier TimOliverMaier Nov 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh boy 😆 . Maybe it was just that. I thought I was comparing np.ones() to np.ones(). will check again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This was my mistake, sorry! Checked it properly with np.ones() and it passes.

np.testing.assert_allclose(list(log_prob.values()), np.array([18.69]))


def test_transforms_sumto1():
    """univariate_sum_to_1 on a batched (10, 4) Normal yields the expected model logp."""
    mu = [-3, -1, 1, 2]
    with pm.Model() as model:
        pm.Normal(
            "x",
            mu=mu,
            sigma=1,
            size=(10, 4),
            transform=tr.univariate_sum_to_1,
        )

    log_prob = model.point_logps()
    expected = np.array([-56.76])
    np.testing.assert_allclose(list(log_prob.values()), expected)