Skip to content

Commit

Permalink
Implement a new unflatten for chex.dataclass that avoids __init__ whi…
Browse files Browse the repository at this point in the history
…le keeping the (un)flattened order unchanged.

PiperOrigin-RevId: 500963457
  • Loading branch information
ChexDev authored and ChexDev committed Jan 10, 2023
1 parent ac35c60 commit 0e0b5c1
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,15 @@ def _init(self, *args, **kwargs):
return dcls


def _dataclass_unflatten(dcls, keys, values):
dcls_object = dcls.__new__(dcls)
attribute_dict = dict(zip(keys, values))
# Looping over fields instead of keys & values preserves the field order.
for field in dataclasses.fields(dcls):
object.__setattr__(dcls_object, field.name, attribute_dict[field.name])
return dcls_object


def register_dataclass_type_with_jax_tree_util(data_class):
"""Register an existing dataclass so JAX knows how to handle it.
Expand All @@ -242,7 +251,7 @@ def register_dataclass_type_with_jax_tree_util(data_class):
in instance.__dict__.
"""
flatten = lambda d: jax.util.unzip2(sorted(d.__dict__.items()))[::-1]
unflatten = lambda keys, values: data_class(**dict(zip(keys, values)))
unflatten = functools.partial(_dataclass_unflatten, data_class)
try:
jax.tree_util.register_pytree_node(
nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten)
Expand Down

0 comments on commit 0e0b5c1

Please sign in to comment.