Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate Flax dataclass to the newest JAX pytree keypath API. #2921

Merged
merged 1 commit into from
Mar 8, 2023

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Mar 2, 2023

Migrate Flax dataclass to the newest JAX pytree keypath API.

  • flax.struct.dataclass already registered via register_keypath, and this one only changes it to latest API.

  • flax.core.FrozenDict was registered so that flattening a frozen dict should be the same as flattening the underlying dict. This makes its serialization backward-compatible.

@copybara-service copybara-service bot force-pushed the test_513588016 branch 5 times, most recently from 100b8bf to 32877d3 Compare March 6, 2023 20:29
Returns:
A flattened version of this FrozenDict instance.
"""
return ((jax.tree_util.GetAttrKey('dict'), self._dict),), ()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tricky. There should be a way to return no key at all since this is more of a proxy type.

Copy link
Collaborator

@cgarciae cgarciae Mar 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, even though it would be slightly slower, we could do a single-level flattening e.g:

children = tuple(((jax.tree_util.GetAttrKey(key), value) for key, value in self._dict.items())
return children, ()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you Cristian! Your comment was addressed :)

@copybara-service copybara-service bot changed the title Migrate Flax dataclass to the newest JAX pytree API. Migrate Flax dataclass to the newest JAX pytree keypath API. Mar 8, 2023
@copybara-service copybara-service bot force-pushed the test_513588016 branch 2 times, most recently from 080d355 to 544cc60 Compare March 8, 2023 02:05
@codecov-commenter
Copy link

codecov-commenter commented Mar 8, 2023

Codecov Report

Merging #2921 (5ad0344) into main (34823e2) will decrease coverage by 0.05%.
The diff coverage is 73.33%.

@@            Coverage Diff             @@
##             main    #2921      +/-   ##
==========================================
- Coverage   81.81%   81.76%   -0.05%     
==========================================
  Files          55       55              
  Lines        5900     5906       +6     
==========================================
+ Hits         4827     4829       +2     
- Misses       1073     1077       +4     
Impacted Files Coverage Δ
flax/struct.py 91.30% <50.00%> (-5.62%) ⬇️
flax/core/frozen_dict.py 96.06% <100.00%> (+0.06%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@copybara-service copybara-service bot force-pushed the test_513588016 branch 2 times, most recently from 5c65ba2 to 5ad0344 Compare March 8, 2023 19:27
* `flax.struct.dataclass` already registered via `register_keypath`, and this one only changes it to latest API.

* `flax.core.FrozenDict` was registered so that flattening a frozen dict should be the same as flattening the underlying dict. This makes its serialization backward-compatible.

PiperOrigin-RevId: 515096454
@copybara-service copybara-service bot merged commit c7fa10a into main Mar 8, 2023
@copybara-service copybara-service bot deleted the test_513588016 branch March 8, 2023 19:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants