Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] always fold_in on fork + new ForkedKeys return type #3722

Merged
merged 1 commit into from
Mar 1, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Feb 27, 2024

What does this PR do?

  • Rngs.fork now returns a single ForkedKeys type instead of different types depending on the overload (overloads are removed). ForkedKeys is a Mapping[str, Array] with the keys, and contains .splits and .broadcasts attributes of type dict[str, Array] for cases when you need to tell apart the splits from the broadcasts.
  • Rngs.fork now folds all keys so there is no need to maintain / propagate hashable data.
  • Because of the previous RngStream.counts: list[int] is replaced with RngStream.count: int.

@cgarciae cgarciae force-pushed the nnx-dynamic-rngs branch 4 times, most recently from 839dff7 to 159e542 Compare February 27, 2024 14:21
@cgarciae cgarciae changed the title [nnx] dynamic RngStream.count + ForkRngs [nnx] always fold_in on fork + ForkRngs Feb 27, 2024
@cgarciae cgarciae changed the title [nnx] always fold_in on fork + ForkRngs [nnx] always fold_in on fork + new ForkedKeys return type Feb 28, 2024
@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 94.73684% with 5 lines in your changes are missing coverage. Please review.

Project coverage is 59.05%. Comparing base (1abfa87) to head (3e69350).
Report is 2 commits behind head on main.

Files Patch % Lines
flax/experimental/nnx/nnx/rnglib.py 94.00% 3 Missing ⚠️
flax/experimental/nnx/nnx/transforms.py 75.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3722      +/-   ##
==========================================
+ Coverage   59.00%   59.05%   +0.04%     
==========================================
  Files         103      103              
  Lines       12422    12438      +16     
==========================================
+ Hits         7330     7345      +15     
- Misses       5092     5093       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@copybara-service copybara-service bot merged commit acba0bf into main Mar 1, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-dynamic-rngs branch March 1, 2024 13:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants