Add LoRA to specific layers #427
Hi @LiJunnan1992! You can already do this by passing the fully qualified module names to `target_modules`:

```python
from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(0)
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        'model.decoder.layers.0.self_attn.v_proj',
        'model.decoder.layers.1.self_attn.v_proj',
        'model.decoder.layers.2.self_attn.v_proj',
    ],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
# trainable params: 98304 || all params: 331294720 || trainable%: 0.029672673322412142
```

For comparison, the default configuration applies LoRA across all layers:

```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(0)
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
# trainable params: 1572864 || all params: 332769280 || trainable%: 0.472659014678278
```

However, this requires users to manually feed the full names of the modules. Does what you had in mind correspond to explicitly giving the numbers of the layers that you want to "ignore" for the LoRA transformation? cc @pacman100
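If you go the `target_modules` route, the full names don't have to be typed by hand. Here is a minimal sketch (assuming the OPT naming scheme `model.decoder.layers.{i}.self_attn.v_proj` shown above; other architectures use different module paths) that builds the list from layer indices:

```python
# Build the fully qualified v_proj names for the layers we want to adapt.
# The index list and the naming pattern are illustrative assumptions.
lora_layer_indices = [0, 1, 2]
target_modules = [
    f"model.decoder.layers.{i}.self_attn.v_proj" for i in lora_layer_indices
]

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=target_modules,
)
```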
Thanks for the reply. The current solution already works for me!
Thanks for the suggestion! This is now supported through the `layers_to_transform` argument, which lets you give the layer indices directly:

```python
from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(0)
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    layers_to_transform=[0, 1, 2],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
```
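As a usage note, `layers_to_transform` can also be combined with a short `target_modules` list so that only the named projections inside the chosen layers are adapted. A minimal sketch (the `q_proj`/`v_proj` choice here is only an illustrative assumption, reusing the same OPT model as above):

```python
# Sketch: restrict LoRA to the q_proj/v_proj projections of layers 0-2 only.
# Short target_modules entries are matched against the end of each module name,
# while layers_to_transform filters by layer index.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    layers_to_transform=[0, 1, 2],
)
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(0), config
)
model.print_trainable_parameters()
```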
Thanks for the great library!
It could be quite useful for many applications to support specifying the layers in which to insert the adapters. For example, completely freezing some of the earlier layers could save a large amount of computation, since fewer layers need back-propagation.
Is there any plan to support this? Or do you have any advice on where I should make changes if I want to implement this myself?
Thank you!
Junnan
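Whichever of the configurations shown above is used, a quick sanity check (a minimal sketch, assuming the model has already been wrapped with `get_peft_model` as in the earlier snippets) is to list the parameters that still require gradients; only the LoRA weights of the selected layers should appear:

```python
# List trainable parameters after wrapping the model with get_peft_model.
# Expect only lora_A / lora_B weights belonging to the chosen layers.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
```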