[NNX] Add LoRA and LoRALinear to NNX #3929
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files (Coverage Diff, main vs. #3929):
- Coverage: 60.43% → 0.00% (-60.44%)
- Files: 105 → 102 (-3)
- Lines: 13263 → 13160 (-103)
- Hits: 8015 → 0 (-8015)
- Misses: 5248 → 13160 (+7912)

☔ View full report in Codecov by Sentry.
flax/experimental/nnx/nnx/nn/lora.py
Outdated
def __call__(self, x: jax.Array):
  out = x @ self.lora_a @ self.lora_b
  if self.base_module is not None:
    assert callable(self.base_module), "`base_module` must be callable."
Let's raise an error here instead.
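For instance, a sketch of the requested check inside `__call__` (the line applying `base_module` is assumed from context; the final error type and message may differ):

```python
def __call__(self, x: jax.Array):
  out = x @ self.lora_a @ self.lora_b
  if self.base_module is not None:
    if not callable(self.base_module):
      raise ValueError('`self.base_module` must be callable.')
    out += self.base_module(x)  # assumed: the base output is added to the LoRA update
  return out
```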
Done.
  return out


class LoRALinear(nnx.Linear):
It might be cleaner to inherit from LoRA and create a Linear in __init__ which is passed as the base_module to super().__init__(). This way you can remove __call__.
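Roughly, the suggested shape would be something like this (a sketch only; the keyword names and the LoRA constructor accepting `base_module` and `rngs` are assumptions, not the final API):

```python
from flax.experimental import nnx

class LoRALinear(nnx.LoRA):
  """Sketch: reuse LoRA's __call__ by passing a Linear as the base module."""

  def __init__(self, in_features: int, out_features: int, lora_rank: int, *, rngs: nnx.Rngs):
    # Create the wrapped Linear and hand it to LoRA as `base_module`;
    # LoRA's __call__ is assumed to add its low-rank update to base_module(x),
    # so no __call__ override is needed here.
    linear = nnx.Linear(in_features, out_features, rngs=rngs)
    super().__init__(
      in_features=in_features,
      out_features=out_features,
      lora_rank=lora_rank,
      base_module=linear,
      rngs=rngs,
    )
```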
The reason I provided LoRALinear is that it assumes the exact same param structure and API as Linear. This is also how PyTorch provides their LoRALinear.

If it's just a LoRA instance, the original linear weights will be one level below, inside base_module. I think users can easily create something like that on their own; there is no need for a LoRALinear shortcut for it.
Makes sense.
flax/experimental/nnx/nnx/nn/lora.py
Outdated
in_features: int,
out_features: int,
lora_rank: int,
Maybe this ordering is a bit more intuitive?
Suggested change:
- in_features: int,
- out_features: int,
- lora_rank: int,
+ in_features: int,
+ lora_rank: int,
+ out_features: int,
Sure! LoRA is a generic layer and this ordering is indeed more intuitive. Note that the PyTorch LoRALinear implementation is ordered (in, out, rank). To avoid confusion, I will make lora_rank a keyword argument in our LoRALinear.
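So the two signatures would look roughly like this (a sketch of the intent, not the final API):

```python
from flax.experimental import nnx

class LoRA(nnx.Module):
  # Generic layer: the rank sits between the in/out dims, per the suggestion above.
  def __init__(self, in_features: int, lora_rank: int, out_features: int, *, rngs: nnx.Rngs):
    ...

class LoRALinear(nnx.Linear):
  # Same positional order as Linear (and PyTorch's LoRALinear); the rank is
  # keyword-only so the two orderings can't be mixed up.
  def __init__(self, in_features: int, out_features: int, *, lora_rank: int, rngs: nnx.Rngs):
    ...
```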
flax/experimental/nnx/nnx/nn/lora.py
Outdated
self.lora_a = param_type(
  kernel_init(rngs.params(), (in_features, lora_rank), param_dtype)
)
self.lora_b = param_type(
I'm wondering if there are more informative variable names we can use, like self.lora_in and self.lora_out, or are a and b the standard naming convention for LoRA?
Yeah, unfortunately AFAIK a and b are the convention here...
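For reference, a and b come from the LoRA paper's low-rank factorization of the weight update; in the x @ lora_a @ lora_b form used here, the shapes work out as in this quick sketch (shapes taken from the snippet above; the variable names are only illustrative):

```python
import jax
import jax.numpy as jnp

in_features, lora_rank, out_features = 8, 2, 4
key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)

lora_a = jax.random.normal(key_a, (in_features, lora_rank))   # down-projection to rank r
lora_b = jax.random.normal(key_b, (lora_rank, out_features))  # up-projection back to out_features

x = jax.random.normal(key_x, (3, in_features))
out = x @ lora_a @ lora_b        # low-rank update, shape (3, out_features)
delta_w = lora_a @ lora_b        # equivalent rank-`lora_rank` weight update, (in, out)
assert jnp.allclose(out, x @ delta_w, atol=1e-5)
```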
flax/experimental/nnx/nnx/nn/lora.py
Outdated
precision: numerical precision of the computation see `jax.lax.Precision`
  for details.
kernel_init: initializer function for the weight matrices.
use_lora_param_type: if yes, LoRA params will be of different Param type.
Suggested change:
- use_lora_param_type: if yes, LoRA params will be of different Param type.
+ use_lora_param_type: if ``True``, LoRA params will be of different Param type.
I am thinking maybe it's better to make this argument an explicit type instead of a boolean... more customizability!
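Something along these lines, for example (the argument name, default, and the LoRAParam subclass are assumptions used only to illustrate the idea):

```python
import typing as tp
from flax.experimental import nnx

class LoRAParam(nnx.Param):
  """Marker Param subclass so LoRA weights can be filtered or trained separately."""

# Instead of `use_lora_param_type: bool`, the layer could accept the type itself,
# e.g. `lora_param_type: tp.Type[nnx.Param] = LoRAParam`, and then store:
#   self.lora_a = lora_param_type(
#     kernel_init(rngs.params(), (in_features, lora_rank), param_dtype)
#   )
```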
flax/experimental/nnx/nnx/nn/lora.py
Outdated
precision: numerical precision of the computation see `jax.lax.Precision`
  for details.
kernel_init: initializer function for the weight matrices.
use_lora_param_type: if yes, LoRA params will be of different Param type.
Suggested change:
- use_lora_param_type: if yes, LoRA params will be of different Param type.
+ use_lora_param_type: if ``True``, LoRA params will be of different Param type.
Same here
Should we also add this to Linen?
Provided here are two ways to add LoRA to any layer:

1. nnx.LoRA, with a base_module arg that can take any layer instance. Easy for model surgery on modules, but the resulting state no longer matches the original layer's.
2. nnx.LoRALinear, which subclasses nnx.Linear and attaches a simple nnx.LoRA module along the way. The param structure matches nnx.Linear, but it needs a bit more surgery at runtime. This technique can be applied to any other NNX module.

Not yet sure which approach is best, so both are provided at this moment.
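A rough usage sketch of the two options (constructor signatures follow the ordering discussed above and should be treated as assumptions until the PR settles):

```python
import jax.numpy as jnp
from flax.experimental import nnx

rngs = nnx.Rngs(0)
x = jnp.ones((1, 3))

# 1. nnx.LoRA wrapping an arbitrary base module: easy surgery on any layer,
#    but the wrapped layer's own params end up nested under `base_module`.
lora = nnx.LoRA(3, 2, 4, base_module=nnx.Linear(3, 4, rngs=rngs), rngs=rngs)
y1 = lora(x)

# 2. nnx.LoRALinear: a subclass of nnx.Linear whose param structure matches a
#    plain Linear (kernel, bias), plus the lora_a / lora_b factors.
lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=rngs)
y2 = lora_linear(x)
```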