Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect unflattenning of inverse transforms #1600

Merged
merged 1 commit into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,10 @@ def inverse_shape(self, shape):
def tree_flatten(self):
return (self._inv,), (("_inv",), dict())

@classmethod
def tree_unflatten(cls, aux_data, params):
return cls(params)
def __eq__(self, other):
if not isinstance(other, _InverseTransform):
return False
return self._inv == other._inv


class AbsTransform(ParameterFreeTransform):
Expand Down
2 changes: 2 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def out_t(transform, x):
assert jitted_in_t(transform, 1.0) == 1.0
assert jitted_out_t(transform, 1.0) == transform

assert jitted_out_t(transform.inv, 1.0) == transform.inv

assert jnp.allclose(
vmap(in_t, in_axes=(None, 0), out_axes=0)(transform, jnp.ones(3)),
jnp.ones(3),
Expand Down