Add NNX surgery guide #4005

Merged · merged 1 commit into google:main on Jun 26, 2024
Conversation

@IvyZX (Collaborator) commented Jun 18, 2024

Add a guide for common model surgery scenarios and how to use NNX states and modules to perform surgery fluently.
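
For a taste of what the guide covers, here is a minimal sketch of sub-module surgery (reusing the TwoLayerMLP definition and the nnx.LoRALinear call quoted later in this thread; the guide's actual examples may differ):

from flax import nnx

class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
# Surgery is plain attribute assignment: swap linear1 for a LoRA variant.
model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(1))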

@IvyZX IvyZX requested review from cgarciae and chiamp June 18, 2024 21:09

nnx.update(good_model, new_state)

# They are equivalent. The way RNGs are propagated in the new model is also preserved.
np.testing.assert_allclose(simple_model(x), good_model(x))
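
For context, a self-contained reconstruction of the hunk above (TwoLayerMLP as in the sketch in the PR description; the definitions of simple_model, good_model, and new_state are assumptions, since only the quoted lines appear in the diff):

import jax
import numpy as np
from flax import nnx

simple_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
good_model = TwoLayerMLP(4, rngs=nnx.Rngs(1))

# Extract one model's parameters and graft them onto the other in place.
new_state = nnx.state(simple_model)
nnx.update(good_model, new_state)

x = jax.random.normal(jax.random.key(42), (3, 4))
np.testing.assert_allclose(simple_model(x), good_model(x))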
@cgarciae (Collaborator) commented Jun 20, 2024
I'm a bit confused by this section; it seems to be doing the same operation in two different ways. I was thinking we could showcase the use of nnx.jit here, with something like:

import functools
from flax import nnx

old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)
def partial_init(old_state, rngs):
  model = TwoLayerMLP(4, rngs=rngs)
  # add the existing state
  nnx.update(model, old_state)
  # create the new state
  model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  return model

model = partial_init(old_state, nnx.Rngs(0))

@IvyZX (Collaborator, Author) commented Jun 21, 2024

Hmm, for some reason the code you proposed didn't actually save memory.

I added some lines to explain this section and also showed how to print the memory cost at each step. You can see that the jax.jit example skips the creation of the extra unused params in LoRALinear. I don't know why your nnx.jit version doesn't reproduce this...

You can run this piece of code in a standalone Python file to verify what I've seen. If nnx.jit can do this, I'm happy to switch to nnx.jit instead.
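
For reference, the check I mean is roughly this sketch, counting the buffers JAX currently holds (here applied to the partial_init from your snippet):

import jax

print('live arrays before:', len(jax.live_arrays()))  # the old_state buffers
model = partial_init(old_state, nnx.Rngs(0))
print('live arrays after:', len(jax.live_arrays()))   # only buffers that actually materialized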

@cgarciae (Collaborator) commented Jun 24, 2024
My bad, I put nnx.update in the middle, but it should be at the end; otherwise the model.linear1 = ... assignment overrides the updated state. Here is the fixed code:

from functools import partial

import jax
from flax import nnx

class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

assert len(jax.live_arrays()) == 4  # two Linears, each holding kernel + bias

@partial(nnx.jit, donate_argnums=0, static_argnums=1)
def partial_init(old_state, rngs):
  model = TwoLayerMLP(4, rngs=rngs)
  # create the new state
  model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  # add the existing state
  nnx.update(model, old_state)
  return model

model = partial_init(old_state, nnx.Rngs(0))

assert len(jax.live_arrays()) == 6  # old params reused; only the LoRA params are new
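
One note on the flags, as I read them: donate_argnums=0 donates the old_state buffers to the computation, so XLA may reuse their memory in place (old_state must not be touched after the call), and static_argnums=1 keeps the Rngs object out of tracing. With the assignment before nnx.update, the freshly initialized base params are dead values that the compiler never materializes.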

@IvyZX (Collaborator, Author) commented

Thanks! Added your code and replaced the jax.jit code.

@codecov-commenter commented

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 0.00%. Comparing base (31adb00) to head (7eeaa7b).
Report is 26 commits behind head on main.

Additional details and impacted files
@@          Coverage Diff           @@
##            main   #4005    +/-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files        106     107     +1     
  Lines      13582   13689   +107     
======================================
- Misses     13582   13689   +107     


@IvyZX IvyZX requested a review from cgarciae June 25, 2024 18:47
@8bitmp3 8bitmp3 self-requested a review June 25, 2024 23:06
@copybara-service copybara-service bot merged commit 0c19b6b into google:main Jun 26, 2024
19 checks passed