From 01ad71bc226443e7b5034fe9545df6b84659969c Mon Sep 17 00:00:00 2001 From: ChexDev Date: Tue, 10 Jan 2023 07:11:34 -0800 Subject: [PATCH] Implement a new unflatten for chex.dataclass that avoids __init__ while keeping the (un)flattened order unchanged. PiperOrigin-RevId: 500985733 --- chex/_src/dataclass.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index 6b2a1b63..90d8a72f 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -228,6 +228,15 @@ def _init(self, *args, **kwargs): return dcls +def _dataclass_unflatten(dcls, keys, values): + dcls_object = dcls.__new__(dcls) + attribute_dict = dict(zip(keys, values)) + # Looping over fields instead of keys & values preserves the field order. + for field in dataclasses.fields(dcls): + object.__setattr__(dcls_object, field.name, attribute_dict[field.name]) + return dcls_object + + def register_dataclass_type_with_jax_tree_util(data_class): """Register an existing dataclass so JAX knows how to handle it. @@ -242,7 +251,7 @@ def register_dataclass_type_with_jax_tree_util(data_class): in instance.__dict__. """ flatten = lambda d: jax.util.unzip2(sorted(d.__dict__.items()))[::-1] - unflatten = lambda keys, values: data_class(**dict(zip(keys, values))) + unflatten = functools.partial(_dataclass_unflatten, data_class) try: jax.tree_util.register_pytree_node( nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten)