Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static_argnums to nn.checkpoint #2457

Merged
merged 2 commits into from
Sep 14, 2022

Conversation

cgarciae
Copy link
Collaborator

What does this PR do?

Fixes #2452. Adds a static_argnums to nn.checkpoint and pipes it to jax.remat, copies argument docstring from JAX.

@cgarciae cgarciae changed the title Forward static_argnums to remat Add static_argnums to nn.checkpoint Sep 12, 2022
@cgarciae cgarciae requested a review from IvyZX September 12, 2022 22:11
@cgarciae cgarciae self-assigned this Sep 12, 2022
@cgarciae cgarciae added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Sep 12, 2022
@cgarciae
Copy link
Collaborator Author

I want to add another test, lets wait before merging.

@cgarciae
Copy link
Collaborator Author

I discussed this PR with @jheek, final behaviour:

  • subtracts -1 from each static_argnum in transforms.checkpoint because self (the Module) is never passed to the lifted transformation.
  • adds +2 to each static_argnum in lift.checkpoint because 2 arguments (variable_groups, rng_groups) are always passed to the inner function in the lifted transformation.

Added an additional test. Should be ready for a final review + internal merge.

@codecov-commenter
Copy link

Codecov Report

Merging #2457 (f384b33) into main (e320e11) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main    #2457   +/-   ##
=======================================
  Coverage   79.66%   79.67%           
=======================================
  Files          49       49           
  Lines        4982     4984    +2     
=======================================
+ Hits         3969     3971    +2     
  Misses       1013     1013           
Impacted Files Coverage Δ
flax/core/lift.py 95.81% <100.00%> (+<0.01%) ⬆️
flax/linen/transforms.py 94.06% <100.00%> (+0.02%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@cgarciae cgarciae requested a review from IvyZX September 14, 2022 16:57
@copybara-service copybara-service bot merged commit 59280a3 into google:main Sep 14, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) pull ready
Projects
None yet
Development

Successfully merging this pull request may close these issues.

flax.linen.remat with concrete=True doesn't work with jax 0.3.17
3 participants