-
Thanks for posting this update. I remembered you commenting on this performance consideration somewhere months back but couldn't find it. I've been doing this split/merge with standard JAX transforms in my code just in case (and to stay as pure JAX as possible). If a PR goes through to address this, I'll try switching to the NNX transforms! 👍
-
Thanks for posting this! Is there a way to run the update with metrics? Or only:
graphdef, state = nnx.split((model, optimizer, metrics))
...
nnx.update((model, optimizer, metrics), state)
Because right now it raises an error.
It would also be great to have a page with speed-up tips for the NNX API!
-
@cgarciae As I mentioned above, I've been sticking to split/merge + JAX transforms to future-proof against any performance hits. However, I would consider switching to NNX transforms for my current development if the expectation is that the Rust extension will definitively close the performance gap. Can you comment on the expected gains with flaxlib?
-
@cgarciae in your example, at the end,
-
Currently nnx.jit traverses the object graph in Python. This is slow and primarily affects the small model regime, as the Python overhead starts to disappear as the model's width grows. To solve this in general, we will be developing a Rust extension called flaxlib (see first steps in #4196) to speed up some of the traversal logic in graph.py, similar to how JAX solved the same issue with jaxlib for standard pytrees.

Meanwhile, there is a pattern you can use to remove the Python overhead using regular jax.jit + nnx.split / nnx.merge to stage out the traversal logic. Take this code that uses nnx.jit as an example:

To speed it up you can use nnx.split before starting the training loop to create a graphdef and state for the NNX objects which are fast to traverse, and then call merge + split inside the jax.jit-decorated function so they only run once during tracing:

After the training loop is done (or whenever needed) nnx.update can be used to update model, optimizer, and metrics to a new state.