Skip to content

Commit

Permalink
Register dataclass at definition so that unpickling the dataclass obj…
Browse files Browse the repository at this point in the history
…ect, type, and treedef works in most cases (with one exception).

PiperOrigin-RevId: 612121935
  • Loading branch information
kho authored and ChexDev committed Mar 7, 2024
1 parent 91de1eb commit 0721454
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,31 @@ def _replace(self, **kwargs):
def _getstate(self):
return self.__dict__

# Patch __setstate__ to register the object on deserialization.
# Register the dataclass at definition. As long as the dataclass is defined
# outside __main__, this is sufficient to make JAX's PyTree registry
# recognize the dataclass and the dataclass' custom PyTreeDef, especially
# when unpickling either the dataclass object, its type, or its PyTreeDef,
# in a different process, because the defining module will be imported.
#
# However, if the dataclass is defined in __main__, unpickling in a
# subprocess does not trigger re-registration. Therefore we also need to
# register when deserializing the object, or construction (e.g. when the
# dataclass type is being unpickled). Unfortunately, there is not yet a way
# to trigger re-registration when the treedef is unpickled as that's handled
# by JAX.
#
# See internal dataclass_test for unit tests demonstrating the problems.
register_dataclass_type_with_jax_tree_util(dcls)

# Patch __setstate__ to register the dataclass on deserialization.
def _setstate(self, state):
register_dataclass_type_with_jax_tree_util(dcls)
self.__dict__.update(state)

orig_init = dcls.__init__

# Patch object's __init__ such that the class is registered on creation if
# it is not registered on deserialization.
# Patch __init__ such that the dataclass is registered on creation if it is
# not registered on deserialization.
@functools.wraps(orig_init)
def _init(self, *args, **kwargs):
register_dataclass_type_with_jax_tree_util(dcls)
Expand Down

0 comments on commit 0721454

Please sign in to comment.