
[NNX] Add LoRA and LoRALinear to NNX #3929

Merged: 1 commit into google:main on May 23, 2024

Conversation

@IvyZX (Collaborator) commented May 21, 2024

Provided here are two ways to add LoRA to any layer:

  • nnx.LoRA, with a base_module arg that can take any layer instance. This makes model surgery on modules easy, but the resulting state structure will not match the wrapped layer's.
  • nnx.LoRALinear, which subclasses nnx.Linear and attaches a small nnx.LoRA module along the way. Its param structure matches nnx.Linear, but it needs a bit more surgery at runtime. The same technique can be applied to any other NNX module.

I'm not yet sure which approach is best, so I'm providing both for now; a rough usage sketch of both follows below.
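A rough usage sketch of the two options, assuming the argument names discussed later in this thread (base_module, lora_rank, rngs); the merged signatures may differ:

import jax.numpy as jnp
from flax.experimental import nnx

rngs = nnx.Rngs(0)

# Option 1: wrap an existing layer via base_module. Easy surgery, but the
# wrapped layer's params live one level down, under base_module.
linear = nnx.Linear(4, 3, rngs=rngs)
lora = nnx.LoRA(in_features=4, lora_rank=2, out_features=3,
                base_module=linear, rngs=rngs)
y = lora(jnp.ones((1, 4)))  # base output plus the low-rank update

# Option 2: LoRALinear keeps the same top-level param structure and
# constructor API as nnx.Linear, with the adapter attached internally.
lora_linear = nnx.LoRALinear(4, 3, lora_rank=2, rngs=rngs)
y = lora_linear(jnp.ones((1, 4)))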

@IvyZX IvyZX requested review from cgarciae and chiamp May 21, 2024 22:39
@IvyZX IvyZX force-pushed the lora branch 4 times, most recently from 9298031 to d22505d, on May 22, 2024 00:56
@codecov-commenter commented May 22, 2024

Codecov Report

Attention: Patch coverage is 0%, with 43 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (2c7d7cd) to head (0f6dc0e).
Report is 46 commits behind head on main.

Files                                   Patch %   Lines
flax/experimental/nnx/nnx/nn/lora.py    0.00%     40 Missing ⚠️
flax/experimental/nnx/__init__.py       0.00%     3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main   #3929       +/-   ##
==========================================
- Coverage   60.43%   0.00%   -60.44%     
==========================================
  Files         105     102        -3     
  Lines       13263   13160      -103     
==========================================
- Hits         8015       0     -8015     
- Misses       5248   13160     +7912     

☔ View full report in Codecov by Sentry.

  def __call__(self, x: jax.Array):
    out = x @ self.lora_a @ self.lora_b
    if self.base_module is not None:
      assert callable(self.base_module), "`base_module` must be callable."
Collaborator:

Let's raise an error here instead.

Collaborator Author (@IvyZX):

Done.

    return out
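Putting the excerpt and the review comment together, a sketch of how the post-review __call__ might look; the out += self.base_module(x) line and the ValueError type are assumptions, since that part of the diff is not shown here:

  def __call__(self, x: jax.Array):
    out = x @ self.lora_a @ self.lora_b
    if self.base_module is not None:
      if not callable(self.base_module):
        raise ValueError('`base_module` must be callable.')
      out += self.base_module(x)  # assumed: add the base layer's output
    return out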


class LoRALinear(nnx.Linear):
Collaborator:

It might be cleaner to inherit from LoRA and create a Linear in __init__ that is passed as the base_module to super().__init__(). This way you can remove __call__. (A sketch of this alternative follows below.)
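A minimal sketch of the alternative being proposed here, for illustration only; this is not what was merged, and the argument names follow the ordering discussed later in this thread:

from flax.experimental import nnx

class LoRALinear(nnx.LoRA):
  """Hypothetical variant: a thin LoRA subclass that owns its own Linear."""

  def __init__(self, in_features: int, out_features: int, *, lora_rank: int,
               rngs: nnx.Rngs):
    linear = nnx.Linear(in_features, out_features, rngs=rngs)
    super().__init__(
        in_features=in_features,
        lora_rank=lora_rank,
        out_features=out_features,
        base_module=linear,  # LoRA.__call__ already adds base_module(x)
        rngs=rngs,
    )
    # No __call__ override needed; the inherited one handles everything.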

Collaborator Author (@IvyZX), May 22, 2024:

The reason I provided LoRALinear is that it assumes the exact same param structure and API as Linear. This is also how PyTorch provides its LoRALinear.

If it's just a LoRA instance, the original linear weights will sit one level below, inside base_module. I think users can easily create something like that on their own, so there's no need for a LoRALinear shortcut (the two layouts are illustrated below).
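For concreteness, an illustrative comparison of the two parameter layouts being contrasted here (paths are approximate, not exact NNX state output):

# nnx.LoRA(base_module=nnx.Linear(...)) -- wrapped layer's params are nested:
#   base_module.kernel, base_module.bias, lora_a, lora_b
#
# nnx.LoRALinear(...) -- same top level as nnx.Linear, plus the adapter:
#   kernel, bias, lora_a, lora_b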

Collaborator:

Makes sense.

in_features: int,
out_features: int,
lora_rank: int,
Collaborator:

Maybe this ordering is a bit more intuitive?

Suggested change:
- in_features: int,
- out_features: int,
- lora_rank: int,
+ in_features: int,
+ lora_rank: int,
+ out_features: int,

Collaborator Author (@IvyZX), May 22, 2024:

Sure! LoRA is a generic layer, and this ordering is indeed more intuitive.

Note that the PyTorch LoRALinear implementation is ordered (in, out, rank). To avoid confusion, I will make lora_rank a keyword-only argument in our LoRALinear (see the signature sketch below).
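A sketch of the constructor signatures being agreed on in this thread; the argument names come from the discussion above, and the merged code may add further options:

from flax.experimental import nnx

class LoRA(nnx.Module):
  # Generic adapter: rank sits between in and out, as suggested above.
  def __init__(self, in_features: int, lora_rank: int, out_features: int,
               *, base_module=None, rngs: nnx.Rngs):
    ...

class LoRALinear(nnx.Linear):
  # Same positional API as nnx.Linear; lora_rank is keyword-only.
  def __init__(self, in_features: int, out_features: int,
               *, lora_rank: int, rngs: nnx.Rngs):
    ...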

Comment on lines 113 to 116:

    self.lora_a = param_type(
      kernel_init(rngs.params(), (in_features, lora_rank), param_dtype)
    )
    self.lora_b = param_type(
Collaborator:

I'm wondering if there are more informative variable names we could use, like self.lora_in and self.lora_out, or are a and b the standard naming convention for LoRA?

Collaborator Author (@IvyZX):

Yeah, unfortunately, AFAIK a and b are the convention here...
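For reference, the names follow the LoRA paper, which writes the low-rank update in terms of matrices A and B. In the right-multiplied form used in this excerpt, the shapes would be as follows (lora_b's shape is inferred from the matmul, not shown in the diff):

# LoRA update: y = x @ W + x @ lora_a @ lora_b, with
#   lora_a ("A"): (in_features, lora_rank)   -- matches the init excerpt above
#   lora_b ("B"): (lora_rank, out_features)  -- inferred from the matmul shape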

precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrices.
use_lora_param_type: if yes, LoRA params will be of different Param type.
@chiamp (Collaborator), May 22, 2024:

Suggested change:
- use_lora_param_type: if yes, LoRA params will be of different Param type.
+ use_lora_param_type: if ``True``, LoRA params will be of different Param type.

Collaborator Author (@IvyZX):

I'm thinking it might be better to make this argument an explicit type instead of a boolean... more customizability! (A rough sketch of that idea follows below.)
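A rough sketch of the explicit-type idea, purely illustrative; the lora_param_type argument name and the LoRAParam class are assumptions here, not something confirmed by this PR:

import jax
from flax.experimental import nnx

class LoRAParam(nnx.Param):
  """Hypothetical Param subclass used to tag LoRA weights for filtering/freezing."""

class LoRA(nnx.Module):
  def __init__(self, in_features: int, lora_rank: int, out_features: int, *,
               lora_param_type: type = LoRAParam,  # explicit type instead of a bool
               rngs: nnx.Rngs):
    init = jax.nn.initializers.lecun_normal()
    self.lora_a = lora_param_type(init(rngs.params(), (in_features, lora_rank)))
    self.lora_b = lora_param_type(init(rngs.params(), (lora_rank, out_features)))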

precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrices.
use_lora_param_type: if yes, LoRA params will be of different Param type.
Collaborator:

Suggested change:
- use_lora_param_type: if yes, LoRA params will be of different Param type.
+ use_lora_param_type: if ``True``, LoRA params will be of different Param type.

Collaborator Author (@IvyZX):

Same here.

@chiamp (Collaborator) left a review comment:

Should we also add this to Linen?

@copybara-service (bot) merged commit b84be49 into google:main on May 23, 2024
19 checks passed