
Add LoRA to specific layers #427

Closed
LiJunnan1992 opened this issue May 10, 2023 · 3 comments · Fixed by #429

@LiJunnan1992

Thanks for the great library!

It could be quite useful for many applications to support specifying the layers into which the adapter is inserted. For example, completely freezing some earlier layers could save a lot of computation, since gradients would only need to be back-propagated through the later layers.

Is there any plan to support this? Or do you have any advice on what I should modify if I want to implement this myself?

Thank you!
Junnan

@younesbelkada commented May 10, 2023

Hi @LiJunnan1992!
Thanks for your message.
Currently we already support applying the LoRA transformation to specific layers; for example, this snippet gives:

from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(0)

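# Example 1: apply LoRA only to the v_proj modules of decoder layers 0, 1 and 2
# by passing their full module names as target_modules.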
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        'model.decoder.layers.0.self_attn.v_proj',
        'model.decoder.layers.1.self_attn.v_proj',
        'model.decoder.layers.2.self_attn.v_proj',
    ]
)

model = get_peft_model(model, config)
model.print_trainable_parameters()
# trainable params: 98304 || all params: 331294720 || trainable%: 0.029672673322412142

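# Example 2: for comparison, with no target_modules the default mapping for OPT
# applies LoRA to the q_proj and v_proj modules of every decoder layer.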
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(0)

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none"
)

model = get_peft_model(model, config)
model.print_trainable_parameters()
# trainable params: 1572864 || all params: 332769280 || trainable%: 0.472659014678278

However, users need to manually provide the full names of the modules. Does what you had in mind correspond to explicitly specifying the layers that you want to "ignore" for the LoRA transformation?

cc @pacman100
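
As a side note, one quick way to discover those full module names is to iterate over PyTorch's named_modules(); here is a minimal sketch using the same facebook/opt-350m checkpoint as above (the variable name is just illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# Collect the fully qualified names of the v_proj modules in the first three
# decoder layers; these strings can be passed directly as target_modules.
v_proj_names = [
    name
    for name, _ in model.named_modules()
    if name.endswith("self_attn.v_proj") and any(f".layers.{i}." in name for i in range(3))
]
print(v_proj_names)
# e.g. ['model.decoder.layers.0.self_attn.v_proj', 'model.decoder.layers.1.self_attn.v_proj', 'model.decoder.layers.2.self_attn.v_proj']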

@LiJunnan1992

Thanks for the reply. The current solution already works for me!
For general users, it might be more convenient if the layers could be specified directly in LoraConfig.

@younesbelkada

Thanks for the suggestion!
I made a PR (#429) that should allow doing this in a more user-friendly manner:

from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(0)

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    layers_to_transform=[0, 1, 2],
)

model = get_peft_model(model, config)
model.print_trainable_parameters()
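
To double-check which modules actually received adapters, one quick sanity check is to list the parameters that remain trainable after wrapping; a minimal sketch, reusing the model from the snippet above:

# List the parameters that are still trainable after get_peft_model;
# only the LoRA A/B matrices inside the transformed layers should show up.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))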
