Skip to content

Commit

Permalink
Disable registration of dataclasses defined in __main__ with JAX tree…
Browse files Browse the repository at this point in the history
… util.

This avoids pickling failures of the sort _pickle.PicklingError: Can't pickle <functools._lru_cache_wrapper object>: it's not the same object as register_dataclass_type_with_jax_tree_util.

PiperOrigin-RevId: 636154723
  • Loading branch information
hamzamerzic authored and ChexDev committed May 22, 2024
1 parent 08ff475 commit 63edbff
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ def _getstate(self):
# by JAX.
#
# See internal dataclass_test for unit tests demonstrating the problems.
register_dataclass_type_with_jax_tree_util(dcls)
# The registration below may result in pickling failures of the sort
# _pickle.PicklingError: Can't pickle <functools._lru_cache_wrapper object>:
# it's not the same object as register_dataclass_type_with_jax_tree_util
# for modules defined in __main__ so we disable registration in this case.
if dcls.__module__ != "__main__":
register_dataclass_type_with_jax_tree_util(dcls)

# Patch __setstate__ to register the dataclass on deserialization.
def _setstate(self, state):
Expand Down

0 comments on commit 63edbff

Please sign in to comment.