Add nnx.training
#3782
Conversation
Codecov Report

```diff
@@            Coverage Diff            @@
##             main    #3782     +/-  ##
==========================================
+ Coverage   60.31%   60.86%   +0.55%
==========================================
  Files         101      105       +4
  Lines       12860    13045     +185
==========================================
+ Hits         7756     7940     +184
- Misses       5104     5105       +1
```

View full report in Codecov by Sentry.
````diff
@@ -170,18 +174,23 @@ This function takes the `state` and a data `batch` and does the following:

 ```{code-cell} ipython3
 @jax.jit
-def train_step(state: TrainState, batch):
+def train_step(state: nnx.State, static: nnx.GraphDef, batch):
````
Reviewer comment: Maybe we could use `nnx.jit` here to make it simpler?
Author reply: Done. Should we show some examples elsewhere on how to use NNX with both `nnx.jit` and `jax.jit` — in toy examples, and/or a guide with a side-by-side comparison, similar to the upgrade guides?
- Add `nnx.training`.
- API reference for `nnx.metrics` and `nnx.optimizer`.
- Had to import `chex.ArrayTree` to make typing happy. Is there an alternative where we don't have to use this dependency? e.g. annotating with `# type: ignore`.
- Renamed the `optimizer.update` method to `optimizer.apply_gradients`, because `optimizer.update` was being called instead of the base class `nnx.Module.update` method when using `nnx.jit`.
- Added `OptState` and `MetricState` `Variable` wrapper classes for `nnx.optimizer.Optimizer` and `nnx.metrics.Metric`, respectively.
- Updated `toy_examples/02_lifted_transforms.py` to use `nnx.optimizer.Optimizer`. Training loop runs slightly slower.
- Updated `docs/experimental/nnx/mnist_tutorial.ipynb` to use `nnx.metrics.MultiMetric` and `nnx.optimizer.Optimizer`.
- For some reason, I'm getting an `nbstripout` error on CI when I try to include the outputs of the mnist tutorial, so I have removed them for now.