[nnx] Add Sphinx Docs #3678
Codecov Report
Attention:
Additional details and impacted files
@@ Coverage Diff @@
## main #3678 +/- ##
==========================================
+ Coverage 56.14% 57.71% +1.56%
==========================================
Files 102 103 +1
Lines 12186 12289 +103
==========================================
+ Hits 6842 7092 +250
+ Misses 5344 5197 -147
View full report in Codecov by Sentry.
ecd9dcd to ab2a30d
ab2a30d to 7837c59
## 5. Create the `TrainState`

In Flax, a common practice is to use a dataclass to encapsulate the training state, including the step number, parameters, and optimizer state. The [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state) class is ideal for basic use cases, simplifying the process by allowing you to pass a single argument to functions like `train_step`.
Interesting, the fusion of Flax and NNX!
Do we have to use `TrainState` in this example? Does it make sense to make it more self-contained and free of abstractions from non-NNX parts of Flax?
The Functional API works 🙂. However, I've been thinking of creating a Module-based `TrainState` that would be even more convenient than the current Pytree API.
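For readers unfamiliar with the pattern under discussion, here is a minimal, library-free sketch of a dataclass-style training state. It is not the `flax.training.train_state.TrainState` implementation: `ToyTrainState`, the hand-written momentum SGD update, the scalar parameters, and the default hyperparameters are all hypothetical stand-ins chosen so the snippet runs with only the standard library. It only illustrates why bundling the step number, parameters, and optimizer state lets `train_step` take a single state argument.

```python
from dataclasses import dataclass, replace
from typing import Dict


@dataclass(frozen=True)
class ToyTrainState:
    """Hypothetical stand-in bundling step, params, and optimizer state."""
    step: int
    params: Dict[str, float]
    opt_state: Dict[str, float]  # here: a per-parameter momentum buffer

    def apply_gradients(self, grads, lr=0.1, momentum=0.9):
        # Momentum SGD: v <- momentum * v + g, then p <- p - lr * v
        new_opt = {k: momentum * self.opt_state[k] + grads[k] for k in grads}
        new_params = {k: self.params[k] - lr * new_opt[k] for k in grads}
        # Return a new state object instead of mutating this one.
        return replace(self, step=self.step + 1,
                       params=new_params, opt_state=new_opt)


def train_step(state, x, y):
    # Squared-error loss for y_hat = w * x + b, gradients written by hand.
    err = state.params["w"] * x + state.params["b"] - y
    grads = {"w": 2 * err * x, "b": 2 * err}
    return state.apply_gradients(grads), err ** 2


state = ToyTrainState(step=0,
                      params={"w": 0.0, "b": 0.0},
                      opt_state={"w": 0.0, "b": 0.0})
state, loss = train_step(state, x=1.0, y=2.0)
```

Because the state is a single immutable value, each call to `train_step` consumes one state and returns the next, which is the convenience the quoted paragraph attributes to `TrainState`.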
y, state = forward(static, state, x=jnp.ones((1, 2)))
# 5. Update the state of the original Module
model.update(state)
How is `model.update(state)` different from `static.merge(state)`?
`update` mutates the original references with the new state, whereas `merge` creates new references with the given state.
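The distinction can be illustrated without NNX itself. The following is a toy sketch, not the NNX implementation: `ToyModule`, `ToyStatic`, and their methods are hypothetical names that only mimic the semantics described in this thread, where `update` writes into the existing object and `merge` builds a new one.

```python
class ToyModule:
    """Hypothetical module holding its mutable state in a dict."""

    def __init__(self, state):
        self.state = dict(state)

    def split(self):
        # Return a static description plus a copy of the current state.
        return ToyStatic(type(self)), dict(self.state)

    def update(self, new_state):
        # Mutates this object in place; existing references see the change.
        self.state.update(new_state)


class ToyStatic:
    """Hypothetical static part: knows how to rebuild a module."""

    def __init__(self, cls):
        self.cls = cls

    def merge(self, state):
        # Builds a brand-new object; the original module is untouched.
        return self.cls(state)


model = ToyModule({"count": 0})
alias = model                       # a second reference to the same object
static, state = model.split()
state["count"] = 1                  # pretend a pure function produced new state

rebuilt = static.merge(state)       # new object carrying the new state
merged_is_new = rebuilt is not model
count_before_update = alias.state["count"]  # original still holds the old value

model.update(state)                 # in place: the alias observes the new value
count_after_update = alias.state["count"]
```

In this sketch, code holding `alias` only sees the new state after `model.update(state)`, while `static.merge(state)` leaves both `model` and `alias` unchanged.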
What does this PR do?
Preview: https://flax--3678.org.readthedocs.build/en/3678/experimental/nnx/index.html