-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Fix ordering Transformation for batched dimensions #6255
Changes from 10 commits
2465a0f
c9a0470
0862134
7fb4c37
2d2bab4
0372528
7ef9b2e
9d279db
7515c73
7442474
bc8166e
8ad2c03
9daae35
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ | |
# limitations under the License. | ||
|
||
|
||
from typing import Union | ||
|
||
import aesara | ||
import aesara.tensor as at | ||
import numpy as np | ||
|
@@ -139,10 +141,12 @@ def test_simplex_accuracy(): | |
|
||
|
||
def test_sum_to_1(): | ||
check_vector_transform(tr.sum_to_1, Simplex(2)) | ||
check_vector_transform(tr.sum_to_1, Simplex(4)) | ||
check_vector_transform(tr.univariate_sum_to_1, Simplex(2)) | ||
check_vector_transform(tr.univariate_sum_to_1, Simplex(4)) | ||
|
||
check_jacobian_det(tr.sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]) | ||
check_jacobian_det( | ||
tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1] | ||
) | ||
|
||
|
||
def test_log(): | ||
|
@@ -241,28 +245,30 @@ def test_circular(): | |
|
||
|
||
def test_ordered(): | ||
check_vector_transform(tr.ordered, SortedVector(6)) | ||
check_vector_transform(tr.univariate_ordered, SortedVector(6)) | ||
|
||
check_jacobian_det(tr.ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False) | ||
check_jacobian_det( | ||
tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False | ||
) | ||
|
||
vals = get_values(tr.ordered, Vector(R, 3), at.dvector, np.zeros(3)) | ||
vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3)) | ||
close_to_logical(np.diff(vals) >= 0, True, tol) | ||
|
||
|
||
def test_chain_values(): | ||
chain_tranf = tr.Chain([tr.logodds, tr.ordered]) | ||
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) | ||
vals = get_values(chain_tranf, Vector(R, 5), at.dvector, np.zeros(5)) | ||
close_to_logical(np.diff(vals) >= 0, True, tol) | ||
|
||
|
||
def test_chain_vector_transform(): | ||
chain_tranf = tr.Chain([tr.logodds, tr.ordered]) | ||
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) | ||
check_vector_transform(chain_tranf, UnitSortedVector(3)) | ||
|
||
|
||
@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.") | ||
def test_chain_jacob_det(): | ||
chain_tranf = tr.Chain([tr.logodds, tr.ordered]) | ||
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) | ||
check_jacobian_det(chain_tranf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False) | ||
|
||
|
||
|
@@ -327,7 +333,14 @@ def check_vectortransform_elementwise_logp(self, model): | |
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) | ||
# Original distribution is univariate | ||
if x.owner.op.ndim_supp == 0: | ||
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) | ||
tr_steps = getattr(transform, "transform_list", [transform]) | ||
transform_keeps_dim = any( | ||
[isinstance(ts, Union[tr.SumTo1, tr.Ordered]) for ts in tr_steps] | ||
) | ||
if transform_keeps_dim: | ||
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim | ||
else: | ||
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) | ||
# Original distribution is multivariate | ||
else: | ||
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim | ||
|
@@ -449,7 +462,7 @@ def test_normal_ordered(self): | |
{"mu": 0.0, "sigma": 1.0}, | ||
size=3, | ||
initval=np.asarray([-1.0, 1.0, 4.0]), | ||
transform=tr.ordered, | ||
transform=tr.univariate_ordered, | ||
) | ||
self.check_vectortransform_elementwise_logp(model) | ||
|
||
|
@@ -467,7 +480,7 @@ def test_half_normal_ordered(self, sigma, size): | |
{"sigma": sigma}, | ||
size=size, | ||
initval=initval, | ||
transform=tr.Chain([tr.log, tr.ordered]), | ||
transform=tr.Chain([tr.log, tr.univariate_ordered]), | ||
) | ||
self.check_vectortransform_elementwise_logp(model) | ||
|
||
|
@@ -479,7 +492,7 @@ def test_exponential_ordered(self, lam, size): | |
{"lam": lam}, | ||
size=size, | ||
initval=initval, | ||
transform=tr.Chain([tr.log, tr.ordered]), | ||
transform=tr.Chain([tr.log, tr.univariate_ordered]), | ||
) | ||
self.check_vectortransform_elementwise_logp(model) | ||
|
||
|
@@ -501,7 +514,7 @@ def test_beta_ordered(self, a, b, size): | |
{"alpha": a, "beta": b}, | ||
size=size, | ||
initval=initval, | ||
transform=tr.Chain([tr.logodds, tr.ordered]), | ||
transform=tr.Chain([tr.logodds, tr.univariate_ordered]), | ||
) | ||
self.check_vectortransform_elementwise_logp(model) | ||
|
||
|
@@ -524,7 +537,7 @@ def transform_params(*inputs): | |
{"lower": lower, "upper": upper}, | ||
size=size, | ||
initval=initval, | ||
transform=tr.Chain([interval, tr.ordered]), | ||
transform=tr.Chain([interval, tr.univariate_ordered]), | ||
) | ||
self.check_vectortransform_elementwise_logp(model) | ||
|
||
|
@@ -536,7 +549,7 @@ def test_vonmises_ordered(self, mu, kappa, size): | |
{"mu": mu, "kappa": kappa}, | ||
size=size, | ||
initval=initval, | ||
transform=tr.Chain([tr.circular, tr.ordered]), | ||
transform=tr.Chain([tr.circular, tr.univariate_ordered]), | ||
) | ||
self.check_vectortransform_elementwise_logp(model) | ||
|
||
|
@@ -545,7 +558,7 @@ def test_vonmises_ordered(self, mu, kappa, size): | |
[ | ||
(0.0, 1.0, (2,), tr.simplex), | ||
(0.5, 5.5, (2, 3), tr.simplex), | ||
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])), | ||
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])), | ||
], | ||
) | ||
def test_uniform_other(self, lower, upper, size, transform): | ||
|
@@ -569,7 +582,11 @@ def test_uniform_other(self, lower, upper, size, transform): | |
def test_mvnormal_ordered(self, mu, cov, size, shape): | ||
initval = np.sort(np.random.randn(*shape)) | ||
model = self.build_model( | ||
pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered | ||
pm.MvNormal, | ||
{"mu": mu, "cov": cov}, | ||
size=size, | ||
initval=initval, | ||
transform=tr.multivariate_ordered, | ||
) | ||
self.check_vectortransform_elementwise_logp(model) | ||
|
||
|
@@ -598,3 +615,31 @@ def test_discrete_trafo(): | |
with pytest.raises(ValueError) as err: | ||
pm.Binomial("a", n=5, p=0.5, transform="log") | ||
err.match("Transformations for discrete distributions") | ||
|
||
|
||
def test_transforms_ordered(): | ||
with pm.Model() as model: | ||
pm.Normal( | ||
"x_univariate", | ||
mu=[-3, -1, 1, 2], | ||
sigma=1, | ||
size=(10, 4), | ||
transform=tr.univariate_ordered, | ||
) | ||
|
||
log_prob = model.point_logps() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To make this test a bit more readable I would suggest including an equivalent You can do that via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the def test_2d_univariate_sum_to_1():
with pm.Model() as model:
x_1d = pm.Normal(
"x_1d",
mu=[-3,-1,1,2],
sigma=1,
size=(10,4),
transform=tr.univariate_sum_to_1,
)
x_2d = pm.Normal(
"x_2d",
mu=[-3, -1, 1, 2],
sigma=1,
size=(10, 4),
transform=tr.univariate_sum_to_1,
)
log_p = model.compile_logp(sum=False)({"x_1d_sumto1__":np.ones((10,3))*0.25,"x_2d_sumto1__":np.zeros((10,3))*0.25})
np.testing.assert_allclose(log_p[0],log_p[1]) Fails with:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok funny. This is not the case if I use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PR looks good otherwise, should we investigate the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was this when comparing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh boy 😆 . Maybe it was just that. I thought I was comparing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. This was my mistake, sorry! Checked it properly with |
||
np.testing.assert_allclose(list(log_prob.values()), np.array([18.69])) | ||
|
||
|
||
def test_transforms_sumto1(): | ||
with pm.Model() as model: | ||
pm.Normal( | ||
"x", | ||
mu=[-3, -1, 1, 2], | ||
sigma=1, | ||
size=(10, 4), | ||
transform=tr.univariate_sum_to_1, | ||
) | ||
|
||
log_prob = model.point_logps() | ||
np.testing.assert_allclose(list(log_prob.values()), np.array([-56.76])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Perhaps a more direct name?