filtering broken for PyTrees with distrax.Distribution objects #269

Closed
kylehkhsu opened this issue Feb 27, 2023 · 2 comments

Comments

@kylehkhsu

Running the MWE from #252 results in:

Traceback (most recent call last):
  File "/iris/u/kylehsu/code/disentangle/scripts/test.py", line 38, in <module>
    opt_state = optim.init(eqx.filter(p, eqx.is_array))
  File "/iris/u/kylehsu/miniconda3/envs/disentangle/lib/python3.10/site-packages/equinox/filters.py", line 128, in filter
    return jtu.tree_map(
  File "/iris/u/kylehsu/miniconda3/envs/disentangle/lib/python3.10/site-packages/jax/_src/tree_util.py", line 206, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/iris/u/kylehsu/miniconda3/envs/disentangle/lib/python3.10/site-packages/jax/_src/tree_util.py", line 206, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Mismatch custom node data: ([dtype('float32'), 10, True, True], [False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Chain[([1, 1, True, True], [False, False, False, False], PyTreeDef({'_bijectors': [CustomNode(Block[([1, 1, True, True, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(Shift[([10, 0, 0, True, True, True], [False, False, False, False, False, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_shift': *}))], [None, None, None, None, None, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), CustomNode(DiagLinear[([True, dtype('float32'), 10, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>], [False, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([1, 1, True, True, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([10, 0, 0, True, True, True, True, True, 0.0], [False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, None, None, None, None, None, None]), 
'_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [None, None, None, None, None, None, None, None, None, None, None, None])], '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [None, None, None, None]), '_distribution': CustomNode(Independent[([1], [False], PyTreeDef({'_distribution': CustomNode(Normal[([True, True], [False, False], PyTreeDef({'_loc': *, '_scale': *}))], [None, None]), '_reinterpreted_batch_ndims': *}))], [None]), '_dtype': *, '_event_shape': (*,), '_loc': *, '_scale': CustomNode(DiagLinear[([True, dtype('float32'), 10, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>], [False, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([1, 1, True, True, 1], [False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([10, 0, 0, True, True, True, True, True, 0.0], [False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, 
'_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, None, None, None, None, None, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [None, None, None, None, None, None, None, None, None, None, None, None]), '_scale_diag': *})) != ([None, None, None, None, None, None, None, dtype('float32'), 10, None, None, None, None, None, None], [True, True, True, True, True, True, True, False, False, True, True, True, True, True, True], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Chain[([None, None, None, None, None, 1, 1, True, True], [True, True, True, True, True, False, False, False, False], PyTreeDef({'_bijectors': [CustomNode(Block[([None, 1, 1, True, True, 1], [True, False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(Shift[([10, 0, 0, True, True, None], [False, False, False, False, False, True], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_shift': *}))], [None, None, None, None, None, *]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [*, None, None, None, None, None]), CustomNode(DiagLinear[([None, None, None, None, dtype('float32'), 10, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method 
Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>], [True, True, True, True, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([None, None, None, 1, 1, True, True, 1], [True, True, True, False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([10, 0, 0, None, True, True, None, None, 0.0], [False, False, False, True, False, False, True, True, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, *, None, None, *, *, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [*, *, *, None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [*, *, *, *, None, None, None, None, None, None, None, None, None, None, None])], '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *}))], [*, *, *, *, *, None, None, None, None]), '_distribution': CustomNode(Independent[([None, None, 1], [True, True, False], PyTreeDef({'_distribution': CustomNode(Normal[([None, None], [True, True], PyTreeDef({'_loc': *, '_scale': *}))], [*, *]), '_reinterpreted_batch_ndims': *}))], [*, *, None]), '_dtype': *, '_event_shape': (*,), '_loc': *, '_scale': CustomNode(DiagLinear[([None, None, None, None, dtype('float32'), 10, 1, 1, True, True, <bound method Block.forward of <distrax._src.bijectors.block.Block object at 
0x7f98393aab60>>, <bound method Block.forward_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse_and_log_det of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>, <bound method Block.inverse_log_det_jacobian of <distrax._src.bijectors.block.Block object at 0x7f98393aab60>>], [True, True, True, True, False, False, False, False, False, False, False, False, False, False, False], PyTreeDef({'_batch_shape': (), '_bijector': CustomNode(Block[([None, None, None, 1, 1, True, True, 1], [True, True, True, False, False, False, False, False], PyTreeDef({'_bijector': CustomNode(ScalarAffine[([10, 0, 0, None, True, True, None, None, 0.0], [False, False, False, True, False, False, True, True, False], PyTreeDef({'_batch_shape': (*,), '_event_ndims_in': *, '_event_ndims_out': *, '_inv_scale': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_log_scale': *, '_scale': *, '_shift': *}))], [None, None, None, *, None, None, *, *, None]), '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, '_ndims': *}))], [*, *, *, None, None, None, None, None]), '_diag': *, '_dtype': *, '_event_dims': *, '_event_ndims_in': *, '_event_ndims_out': *, '_is_constant_jacobian': *, '_is_constant_log_det': *, 'forward': *, 'forward_log_det_jacobian': *, 'inverse': *, 'inverse_and_log_det': *, 'inverse_log_det_jacobian': *}))], [*, *, *, *, None, None, None, None, None, None, None, None, None, None, None]), '_scale_diag': *})); value: <distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x7f983956fb80>.

My package versions are jaxlib==0.4.4, equinox==0.10.1, and distrax==0.1.3.

The discussion at #252 leads me to believe that this did not happen previously. I would like my eqx.Modules to be able to hold distrax.Distribution objects as attributes/subtrees.

@kylehkhsu
Author

kylehkhsu commented Feb 27, 2023

Ah, just noticed this distrax issue was discussed in #252. I suppose I should not expect compatibility with distrax. Closing, but feel free to re-open.

@patrick-kidger
Owner

I've since found that the easiest possible fix for this is probably something like

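from typing import Any

import distrax
import equinox as eqx
from jaxtyping import PyTree

class FixedDistrax(eqx.Module):
    cls: type
    args: PyTree[Any]
    kwargs: PyTree[Any]

    def __init__(self, cls, *args, **kwargs):
        # Store the constructor and its arguments rather than the Distrax object itself.
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def log_prior(self, x):
        # Build the distribution on demand and delegate to its log_prob.
        return self.cls(*self.args, **self.kwargs).log_prob(x)

# mu and sigma are the loc and scale_diag arrays, defined elsewhere.
prior = FixedDistrax(distrax.MultivariateNormalDiag, mu, sigma)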

This way (a) the Distrax objects are transitory and never need to be flattened/unflattened, and (b) all of their args and kwargs are part of the pytree.
