'LoraModel.disable_adapter_layers' not causing the model to use the original modules #493

Closed
glerzing opened this issue May 23, 2023 · 14 comments · Fixed by #736
Labels
bug Something isn't working

Comments

@glerzing
Contributor

I am writing a unit test for another project that does the following:

  • Create a LoRA model using the argument modules_to_save to train one additional layer (e.g., base_model.model.transformer.h.3.mlp.c_proj for gpt2)
  • Make some logit predictions with the initial model
  • Do some training
  • Disable the LoRA adapter with the disable_adapter_layers method
  • Make new logit predictions and verify that they match the predictions from before training

This test fails. While the LoRA layers are indeed bypassed thanks to the call to LoraModel.disable_adapter_layers, the layer given in modules_to_save is used with its post-training parameters, not its original parameters: the forward pass goes through base_model.model.transformer.h.3.mlp.c_proj.modules_to_save instead of base_model.model.transformer.h.3.mlp.c_proj.original_module.

peft version: 0.3.0

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@glerzing
Contributor Author

This doesn't seem to be solved yet, from what I see in ModulesToSaveWrapper: when the adapter is disabled, the forward pass probably still uses modules_to_save instead of original_module.

@younesbelkada
Contributor

Hi @glerzing
Thanks for the issue. Could you share a small, self-contained snippet that reproduces the problem?

@glerzing
Contributor Author

Here it is:

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.0,
    modules_to_save=["base_model.model.transformer.h.3.mlp"],
)

tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
model = get_peft_model(model, peft_config)

inputs = tokenizer("Hello world", return_tensors="pt")

initial_logits = model(**inputs, return_dict=True).logits

# Backpropagate an arbitrary loss; the aim is just to change the parameter values.
loss = torch.nn.functional.binary_cross_entropy_with_logits(initial_logits[0][-1][:1], torch.tensor([0.53]))
loss.backward()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.step()

# Check that the backpropagation worked
new_logits = model(**inputs, return_dict=True).logits
assert not torch.equal(initial_logits, new_logits)

model.disable_adapter_layers()
logits_without_adapter = model(**inputs, return_dict=True).logits
# Should trigger an error if the problem is not solved
assert torch.equal(initial_logits, logits_without_adapter)
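
If it helps with debugging, here is a minimal sketch for checking which copy actually runs while the adapter is disabled. It reuses model and inputs from the snippet above and assumes peft's internal layout (a ModulesToSaveWrapper exposing original_module and a modules_to_save ModuleDict keyed by adapter name):

# Hedged sketch: hook both copies inside the wrapper and see which one fires.
# The dotted path matches the modules_to_save entry above and may differ for other models.
wrapper = model.get_submodule("base_model.model.transformer.h.3.mlp")

calls = []
wrapper.original_module.register_forward_hook(lambda m, i, o: calls.append("original_module"))
wrapper.modules_to_save["default"].register_forward_hook(lambda m, i, o: calls.append("modules_to_save"))

model.disable_adapter_layers()
model(**inputs)
print(calls)  # expected: only "original_module"; with the bug, "modules_to_save" appears instead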

@younesbelkada younesbelkada self-assigned this Jun 23, 2023
@younesbelkada younesbelkada added the bug Something isn't working label Jun 23, 2023
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

I could reproduce the issue, but commenting out this line fixed it:

modules_to_save=["base_model.model.transformer.h.3.mlp"]

The problem is that this layer will be updated during training in the usual fashion, i.e. without LoRA. Therefore, disabling the LoRA adapter will not revert the changes to that specific layer, which explains the different results.
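
To make that concrete, here is a hedged sketch (run after the reproduction above) that compares the frozen copy with the trainable copy inside the wrapper. The attribute names original_module and modules_to_save["default"] follow peft's internal ModulesToSaveWrapper layout and may change between versions:

import torch
from peft.utils.other import ModulesToSaveWrapper

# After the optimizer step, only the copy in modules_to_save has been updated,
# so the two copies diverge and disabling LoRA alone cannot restore the old output.
for name, module in model.named_modules():
    if isinstance(module, ModulesToSaveWrapper):
        frozen = dict(module.original_module.named_parameters())
        tuned = dict(module.modules_to_save["default"].named_parameters())
        diverged = any(not torch.equal(frozen[k], tuned[k]) for k in frozen)
        print(f"{name}: weights diverged after training: {diverged}")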

@glerzing
Contributor Author

The aim of my test is indeed to verify the correct behavior of disable_adapter_layers on a model with modules_to_save.

Isn't the point of duplicating the layer into modules_to_save and original_module precisely to be able to disable the adapter with disable_adapter_layers and re-enable it? It's practical because you can use both the original model and the fine-tuned model within a single model, without doubling the memory requirement: only the fine-tuned, non-LoRA layers are duplicated.
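
If it's useful, a small sketch (relying on peft internals, so treat it as an approximation) to see how little is actually duplicated; only the wrapped modules carry a second copy, everything else is shared:

from peft.utils.other import ModulesToSaveWrapper

# Count the extra parameters held by the duplicated copies inside
# ModulesToSaveWrapper instances, relative to the total parameter count.
total = sum(p.numel() for p in model.parameters())
duplicated = sum(
    p.numel()
    for m in model.modules()
    if isinstance(m, ModulesToSaveWrapper)
    for p in m.modules_to_save.parameters()
)
print(f"duplicated {duplicated} of {total} parameters ({100 * duplicated / total:.2f}%)")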

@BenjaminBossan
Member

Yes, I think you're correct. I'll be adding a fix for this, thanks for digging into it.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Jul 20, 2023
Resolves huggingface#493

For LoRA and IA³, there was a bug where, even when using the
disable_adapter context, a module listed in modules_to_save would use its
updated weights instead of the original weights. This meant that
disable_adapter would not return the same results as the base model
without adaptation. This PR fixes the issue and provides a test.

Note: I tried to adjust AdaLoRA too, since it seemed that the same
reasoning should apply there. However, I think that AdaLoRA does not
really support disabling adapters at all. E.g. there is no
disable_adapter_layers method. Therefore, AdaLoRA was not changed.
@BenjaminBossan
Member

@glerzing See #736

BenjaminBossan added a commit that referenced this issue Jul 24, 2023
Resolves #493

@CRLqinliang

CRLqinliang commented Aug 24, 2023

Hi @glerzing, I'm glad you have already asked this question.
So does that mean I can use disable_adapter_layers() to get the original output and enable_adapter_layers() to get the LoRA output?
And the original model stays frozen whether the adapter is disabled or enabled, right? Because I don't want to change the original parameters.
@BenjaminBossan

@BenjaminBossan
Member

So does that mean I can use disable_adapter_layers() to get the original output and enable_adapter_layers() to get the LoRA output?

Yes. If that doesn't work correctly, please let us know.

And the original model stays frozen whether the adapter is disabled or enabled, right? Because I don't want to change the original parameters.

Yes, freezing is not affected by disabling. That said, you shouldn't train while the adapter is disabled anyway; there would be no point.
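
For reference, a minimal sketch of that toggling pattern (model and inputs as in the reproduction above; newer peft versions also provide a disable_adapter() context manager):

# Toggle between the base model's output and the LoRA output on the same
# PeftModel instance; the frozen base weights are never modified.
model.disable_adapter_layers()
base_logits = model(**inputs).logits   # original behaviour

model.enable_adapter_layers()
lora_logits = model(**inputs).logits   # fine-tuned behaviour

# Alternatively, as a context manager (re-enables automatically on exit):
with model.disable_adapter():
    base_logits = model(**inputs).logits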

@CRLqinliang

CRLqinliang commented Sep 1, 2023

Okay, I see.
By the way, could I enable only the LoRA layers in the encoder part of an LLM and disable those in the decoder part?
@BenjaminBossan

@BenjaminBossan
Member

No, there is currently no option to enable only part of an adapter. What you could try is to load the LoRA adapter state dict manually with torch.load, delete the items you don't want, and then load only the remaining items into the peft model with set_peft_model_state_dict.
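
Untested, but a sketch of that idea could look like the following (the checkpoint path adapter_model.bin and the "decoder" substring filter are assumptions; adjust them for your checkpoint and architecture):

import torch
from peft.utils import set_peft_model_state_dict

# Load the full adapter state dict, drop the decoder entries, and load only
# the remaining (encoder) weights back into the peft model.
state_dict = torch.load("adapter_model.bin", map_location="cpu")  # hypothetical checkpoint path
encoder_only = {k: v for k, v in state_dict.items() if "decoder" not in k}
set_peft_model_state_dict(model, encoder_only)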

@AntoineBlanot

from peft import PeftModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils.other import ModulesToSaveWrapper


def disable_encoder_adapters(model: PeftModel) -> PeftModel:
    # Disable adapter layers whose module path contains "encoder";
    # keep every other adapter layer enabled.
    for name, module in model.named_modules():
        if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
            if "encoder" in name:
                module.enable_adapters(enabled=False)
            else:
                module.enable_adapters(enabled=True)
    return model


model = disable_encoder_adapters(model)

This code should enable only the adapter layers that are NOT in the encoder (so for encoder-decoder models, only the decoder layers stay enabled). You can adapt it as you wish, for example if you only want to disable a specific layer.
