Add NNX surgery guide #4005
Conversation
docs/nnx/surgery.md (outdated):

```python
nnx.update(good_model, new_state)

# They are equivalent. The way RNGs are propagated in the new model is also preserved.
np.testing.assert_allclose(simple_model(x), good_model(x))
```
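For context, a minimal runnable sketch of what this snippet appears to exercise, assuming `simple_model` and `good_model` are two instances of the `TwoLayerMLP` module defined later in this thread; only the last two lines come from the diff, the surrounding setup is my reconstruction:

```python
import jax.numpy as jnp
import numpy as np
from flax import nnx

class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

# Two independently initialized models.
simple_model = TwoLayerMLP(4, rngs=nnx.Rngs(42))
good_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))

# Copy one model's parameters into the other.
new_state = nnx.state(simple_model)
nnx.update(good_model, new_state)

# They are equivalent. The way RNGs are propagated in the new model is also preserved.
x = jnp.ones((1, 4))
np.testing.assert_allclose(simple_model(x), good_model(x))
```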
I'm a bit confused by this section; it seems to be doing the same operation in two different ways. I was thinking we could showcase the use of `nnx.jit` here and use something like:
```python
import functools
from flax import nnx

old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)
def partial_init(old_state, rngs):
  model = TwoLayerMLP(4, rngs=rngs)
  # add existing state
  nnx.update(model, old_state)
  # create new state
  model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  return model

model = partial_init(old_state, nnx.Rngs(0))
```
Hmm, for some reason the code you proposed didn't really save memory. I added some lines to explain this section and also showed how to print the memory cost at each step. You can see that the `jax.jit` example skipped the creation of the extra unused params in `LoRALinear`. I don't know why your `nnx.jit` version doesn't repro this... You can run this piece of code in a standalone Python file to verify what I've seen. If `nnx.jit` can do this, I'm happy to switch to `nnx.jit` instead.
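For reference, a minimal sketch of the kind of per-step memory printout being described, assuming the counting is done with `jax.live_arrays()` (the same call the asserts below rely on); the helper name `report_live_buffers` is illustrative, not from the guide:

```python
import jax

def report_live_buffers(label: str):
  # Count the device arrays currently alive and their total byte size;
  # each nnx.Linear contributes two (kernel and bias).
  arrays = jax.live_arrays()
  print(f'{label}: {len(arrays)} live arrays, '
        f'{sum(a.nbytes for a in arrays):,} bytes')

# e.g. call report_live_buffers('after partial_init') after each step
# to see whether unused parameter initializations were elided.
```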
My bad, I put `nnx.update` in the middle, but it should be at the end; otherwise the `model.linear1 = ...` assignment overrides the updated state. Here is the fixed code:
```python
from functools import partial

import jax
from flax import nnx

class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
assert len(jax.live_arrays()) == 4

@partial(nnx.jit, donate_argnums=0, static_argnums=1)
def partial_init(old_state, rngs):
  model = TwoLayerMLP(4, rngs=rngs)
  # create new state
  model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  # add existing state
  nnx.update(model, old_state)
  return model

model = partial_init(old_state, nnx.Rngs(0))
assert len(jax.live_arrays()) == 6
```
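My reading of the final count (an inference, not stated in the thread): `old_state`'s four donated buffers are freed, and the six live arrays are `LoRALinear`'s kernel, bias, and two low-rank LoRA factors, plus `linear2`'s kernel and bias. A small usage sketch with the shapes assumed above:

```python
import jax.numpy as jnp

# The surgered model is callable like the original: linear1 now applies
# the base linear plus the low-rank LoRA update, then linear2 as before.
x = jnp.ones((1, 4))
y = model(x)
assert y.shape == (1, 4)
```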
Thanks! Added your code and replaced the `jax.jit` code.
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main   #4005     +/-  ##
========================================
  Coverage    0.00%   0.00%
========================================
  Files         106     107      +1
  Lines       13582   13689    +107
========================================
- Misses      13582   13689    +107
```

☔ View full report in Codecov by Sentry.
Add a guide for common model surgery scenarios and how to use NNX state and modules to perform surgery fluently.