'LoraModel.disable_adapter_layers' not causing the model to use the original modules #493
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
This doesn't seem solved yet from what I can see.
Hi @glerzing, could you share a script that reproduces the issue?
Here it is:

```python
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.0,
    modules_to_save=["base_model.model.transformer.h.3.mlp"],
)

tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
model = get_peft_model(model, peft_config)

inputs = tokenizer("Hello world", return_tensors="pt")
initial_logits = model(**inputs, return_dict=True).logits

# Cause whatever loss to be backpropagated. The aim is just to change the parameter values.
loss = torch.nn.functional.binary_cross_entropy_with_logits(initial_logits[0][-1][:1], torch.tensor([0.53]))
loss.backward()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.step()

# Check that the backpropagation worked
new_logits = model(**inputs, return_dict=True).logits
assert not torch.equal(initial_logits, new_logits)

model.disable_adapter_layers()
logits_without_adapter = model(**inputs, return_dict=True).logits

# Should trigger an error if the problem is not solved
assert torch.equal(initial_logits, logits_without_adapter)
```
I could reproduce the issue, but commenting out the `modules_to_save` line in the `LoraConfig` fixed it.
The problem is that this layer will be updated during training in the usual fashion, i.e. without LoRA. Therefore, disabling the LoRA adapter will not revert the changes to that specific layer, which explains the different results.
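To make that explanation concrete, here is a small inspection one could append to the reproduction script; it is a sketch, assuming the wrapped module sits at the path given in the config, that the default adapter name is `"default"`, and that the MLP block exposes a `c_proj` linear layer (as in GPT-Neo, which TinyStories-1M uses):

```python
# Hypothetical check, continuing the reproduction script above.
wrapper = model.get_submodule("base_model.model.transformer.h.3.mlp")

# The wrapper keeps two copies of the layer: the frozen original and a trainable deep copy.
orig_weight = wrapper.original_module.c_proj.weight                # untouched by training
trained_weight = wrapper.modules_to_save["default"].c_proj.weight  # updated by the optimizer step

print(torch.equal(orig_weight, trained_weight))  # expected: False, only the copy changed
```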
The aim of my test is indeed to verify the correct behavior of `disable_adapter_layers`. Isn't the point of having duplicated the layers into `modules_to_save` that the untouched `original_module` can be used when the adapters are disabled?
Yes, I think you're correct. I'll be adding a fix for this, thanks for digging into it.
Resolves #493

For LoRA and IA³, there was a bug where, even when using the disable_adapter context, if a module was listed in modules_to_save, the updated weights would be used instead of the original weights. This meant that disable_adapter would not return the same results as the base model without adaptation. This PR fixes the issue and provides a test.

Note: I tried to adjust AdaLoRA too, since it seemed that the same reasoning should apply there. However, I think that AdaLoRA does not really support disabling adapters at all; e.g. there is no disable_adapter_layers method. Therefore, AdaLoRA was not changed.
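As a usage-level illustration of what the fix is meant to guarantee, here is a sketch continuing the reproduction script above (not the PR's own test); it relies on the `disable_adapter` context manager mentioned in the PR description:

```python
# Reload an untouched copy of the base model for reference.
reference = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
reference_logits = reference(**inputs, return_dict=True).logits

# With the fix, disabling the adapter also routes modules_to_save through the
# original (untrained) copies, so the outputs should match the base model.
with model.disable_adapter():
    disabled_logits = model(**inputs, return_dict=True).logits

assert torch.allclose(reference_logits, disabled_logits)
```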
Hi @glerzing, I am so glad you have already asked this question.
Yes. If that doesn't work correctly, please let us know.
Yes, freezing is not affected by disabling. When you disable the adapter you shouldn't train anyway; there would be no point.
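A quick way to see this, sketched under the assumption that LoRA parameter names contain `lora_` (peft's naming scheme):

```python
# Disabling the adapter layers only changes which weights the forward pass uses;
# it does not touch requires_grad on the LoRA parameters.
model.disable_adapter_layers()
lora_trainable = [p.requires_grad for n, p in model.named_parameters() if "lora_" in n]
print(all(lora_trainable))  # still True
model.enable_adapter_layers()
```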
Okay, I see.
No, there is no option currently to only enable part of an adapter. I guess what you could try out is to load and edit the LoRA adapter state dict manually.
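The exact utility meant above is not preserved in this thread; one possible reading, sketched here as an assumption, is to use peft's `get_peft_model_state_dict` / `set_peft_model_state_dict` helpers and neutralise part of the adapter by zeroing its `lora_B` weights (a zero `lora_B` makes the LoRA update a no-op for that layer):

```python
from peft import get_peft_model_state_dict, set_peft_model_state_dict

# Grab the adapter weights, zero lora_B for the layers to switch off (here:
# everything in the encoder), then load the edited state dict back. Keep
# adapter_state around if you want to restore the trained weights later.
adapter_state = get_peft_model_state_dict(model)
edited = {
    name: torch.zeros_like(tensor) if ("encoder" in name and "lora_B" in name) else tensor
    for name, tensor in adapter_state.items()
}
set_peft_model_state_dict(model, edited)
```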
This code should only enable adapter layers that are NOT in the encoder (in the case of encoder-decoder models, it will only enable the layers in the decoder). You can change it as you wish (for example, if you only want to disable a specific layer).
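The code that comment refers to did not survive in this thread; the following is a guess at its shape, assuming peft 0.3.x where each LoRA layer carries a `disable_adapters` flag (the same flag that `enable_adapter_layers` / `disable_adapter_layers` toggle):

```python
from peft.tuners.lora import LoraLayer

for name, module in model.named_modules():
    if isinstance(module, LoraLayer):
        # Disable the adapter inside the encoder, keep it active everywhere else
        # (i.e. in the decoder of an encoder-decoder model).
        module.disable_adapters = "encoder" in name
```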
I am doing a unit test for another project that does this (the module in question is `base_model.model.transformer.h.3.mlp.c_proj` for gpt2). This test is failing: what I see is that while the LoRA layers are indeed bypassed due to the call to `LoraModel.disable_adapter_layers`, the layer given in `modules_to_save` is used with its post-training parameters, not its original parameters. The forward pass uses the layer `base_model.model.transformer.h.3.mlp.c_proj.modules_to_save` instead of the layer `base_model.model.transformer.h.3.mlp.c_proj.original_module`.

peft version: 0.3.0