Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nnx.training #3782

Merged
merged 1 commit into from
Apr 1, 2024
Merged

Conversation

chiamp
Copy link
Collaborator

@chiamp chiamp commented Mar 22, 2024

Add nnx.training.

API reference for nnx.metrics and nnx.optimizer.

Had to import chex.ArrayTree to make typing happy. Is there an alternative where we don't have to use this dependency? e.g. annotating with # type: ignore.

Renamed optimizer.update method to optimizer.apply_gradients method, because optimizer.update was being called instead of the base class nnx.Module.update method when using nnx.jit.
Added OptState and MetricState Variable wrapper classes for nnx.optimizer.Optimizer and nnx.metrics.Metric, respectively.

Updated toy_examples/02_lifted_transforms.py to use nnx.optimizer.Optimizer. Training loop runs slightly slower.
Updated docs/experimental/nnx/mnist_tutorial.ipynb to use nnx.metrics.MultiMetric and nnx.optimizer.Optimizer.

For some reason, I'm getting an nbstripout error on CI when I try to include the outputs of the mnist tutorial, so I have removed them for now.

@chiamp chiamp self-assigned this Mar 22, 2024
@chiamp chiamp marked this pull request as draft March 22, 2024 00:40
@codecov-commenter
Copy link

codecov-commenter commented Mar 22, 2024

Codecov Report

Attention: Patch coverage is 99.45946% with 1 lines in your changes are missing coverage. Please review.

Project coverage is 60.86%. Comparing base (f2bdcd8) to head (543686b).

Files Patch % Lines
flax/experimental/nnx/nnx/training/metrics.py 97.77% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3782      +/-   ##
==========================================
+ Coverage   60.31%   60.86%   +0.55%     
==========================================
  Files         101      105       +4     
  Lines       12860    13045     +185     
==========================================
+ Hits         7756     7940     +184     
- Misses       5104     5105       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@chiamp chiamp force-pushed the nnx_module_subclass branch 2 times, most recently from d06c963 to 297c388 Compare March 22, 2024 02:20
@chiamp chiamp requested a review from cgarciae March 22, 2024 02:28
flax/experimental/nnx/nnx/training/metrics.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/training/metrics.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/training/metrics.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/training/metrics.py Outdated Show resolved Hide resolved
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@@ -170,18 +174,23 @@ This function takes the `state` and a data `batch` and does the following:

```{code-cell} ipython3
@jax.jit
def train_step(state: TrainState, batch):
def train_step(state: nnx.State, static: nnx.GraphDef, batch):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we could use nnx.jit here to make it simpler?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Should we show some examples elsewhere on how to use NNX with both nnx.jit and jax.jit; in toy examples, and/or a guide with a side-by-side comparison, similar to the upgrade guides?

@cgarciae cgarciae mentioned this pull request Mar 27, 2024
@chiamp chiamp force-pushed the nnx_module_subclass branch 3 times, most recently from cdd1c5d to 4726935 Compare March 29, 2024 20:08
@cgarciae cgarciae marked this pull request as ready for review April 1, 2024 19:50
@copybara-service copybara-service bot merged commit 5d9fd9f into google:main Apr 1, 2024
19 checks passed
@chiamp chiamp deleted the nnx_module_subclass branch April 1, 2024 22:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants