-
Notifications
You must be signed in to change notification settings - Fork 663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate Flax dataclass to the newest JAX pytree keypath API. #2921
Conversation
100b8bf
to
32877d3
Compare
flax/core/frozen_dict.py
Outdated
Returns: | ||
A flattened version of this FrozenDict instance. | ||
""" | ||
return ((jax.tree_util.GetAttrKey('dict'), self._dict),), () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is tricky. There should be a way to return no key at all since this is more of a proxy type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, even though it would be slightly slower, we could do a single-level flattening e.g:
children = tuple(((jax.tree_util.GetAttrKey(key), value) for key, value in self._dict.items())
return children, ()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you Cristian! Your comment was addressed :)
32877d3
to
7c74517
Compare
080d355
to
544cc60
Compare
Codecov Report
@@ Coverage Diff @@
## main #2921 +/- ##
==========================================
- Coverage 81.81% 81.76% -0.05%
==========================================
Files 55 55
Lines 5900 5906 +6
==========================================
+ Hits 4827 4829 +2
- Misses 1073 1077 +4
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
5c65ba2
to
5ad0344
Compare
* `flax.struct.dataclass` already registered via `register_keypath`, and this one only changes it to latest API. * `flax.core.FrozenDict` was registered so that flattening a frozen dict should be the same as flattening the underlying dict. This makes its serialization backward-compatible. PiperOrigin-RevId: 515096454
5ad0344
to
c7fa10a
Compare
Migrate Flax dataclass to the newest JAX pytree keypath API.
flax.struct.dataclass
already registered viaregister_keypath
, and this one only changes it to latest API.flax.core.FrozenDict
was registered so that flattening a frozen dict should be the same as flattening the underlying dict. This makes its serialization backward-compatible.