From 1ff8743909a9755b13398d098c692d37cac62226 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 4 Jun 2023 22:50:26 +0100 Subject: [PATCH] Fix incorrect unflattenning of inverse transforms --- numpyro/distributions/transforms.py | 7 ++++--- test/test_transforms.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 6bd86f14a..a3ec798e1 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -163,9 +163,10 @@ def inverse_shape(self, shape): def tree_flatten(self): return (self._inv,), (("_inv",), dict()) - @classmethod - def tree_unflatten(cls, aux_data, params): - return cls(params) + def __eq__(self, other): + if not isinstance(other, _InverseTransform): + return False + return self._inv == other._inv class AbsTransform(ParameterFreeTransform): diff --git a/test/test_transforms.py b/test/test_transforms.py index 54316dd88..637b6d969 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -140,6 +140,8 @@ def out_t(transform, x): assert jitted_in_t(transform, 1.0) == 1.0 assert jitted_out_t(transform, 1.0) == transform + assert jitted_out_t(transform.inv, 1.0) == transform.inv + assert jnp.allclose( vmap(in_t, in_axes=(None, 0), out_axes=0)(transform, jnp.ones(3)), jnp.ones(3),