[Core] Enhancements and refactoring of LoRA method (#695)
* refactor lora and add utils

1. Refactor the LoRA code.
2. Add a method to delete LoRA adapters.
3. Add a method to unload the PEFT LoRA model.
4. Add `svd` weighted adapter support.
5. Minor fixes.

* fixes

* fixes

* Update lora.py

* fixes

* Update lora.py

* docstrings for the added public APIs

* docs

* Update src/peft/tuners/lora.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* resolve comments, refactoring and adding tests

* fix the remaining failing tests

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
pacman100 and BenjaminBossan authored Jul 14, 2023
1 parent 61a8e3a commit 86ad5ce
Showing 10 changed files with 450 additions and 111 deletions.
16 changes: 16 additions & 0 deletions docs/source/conceptual_guides/lora.mdx
@@ -43,6 +43,22 @@ While LoRA is significantly smaller and faster to train, you may encounter latency

This works because during training, the smaller weight matrices (*A* and *B* in the diagram above) are separate. But once training is complete, the weights can actually be merged into a new weight matrix that is identical.

## Utils for LoRA

Use [`~LoraModel.merge_adapter`] to merge the LoRA layers into the base model while retaining the PeftModel.
This makes it possible to later unmerge, delete, or load different adapters.

Use [`~LoraModel.unmerge_adapter`] to unmerge the LoRA layers from the base model while retaining the PeftModel.
This makes it possible to later merge again, delete, or load different adapters.
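As a quick illustration, here is a minimal sketch. It assumes `model` is an existing `PeftModel` with a LoRA adapter attached, and that the underlying `LoraModel` is reached through `model.base_model` (an assumption about the access path):

```python
lora_model = model.base_model  # assumed handle to the underlying LoraModel

# Fold the active LoRA weights into the base weights to avoid the extra
# LoRA matmuls at inference time; the PeftModel wrapper is kept.
lora_model.merge_adapter()

# Restore the separate LoRA weights, for example to resume training
# or to switch to a different adapter.
lora_model.unmerge_adapter()
```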

Use [`~LoraModel.unload`] to get back the base model without merging the active LoRA modules.
This is useful when you want to reset the model to its original pretrained state.
For example, in the Stable Diffusion WebUI, this lets the user run inference with the base model after trying out LoRAs.
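A minimal sketch, reusing the assumed `lora_model` handle from the sketch above:

```python
# Strip all LoRA modules without merging them and recover the plain
# pretrained transformer.
base_model = lora_model.unload()
```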

Use [`~LoraModel.delete_adapter`] to delete an existing adapter.

Use [`~LoraModel.add_weighted_adapter`] to combine multiple LoRAs into a new adapter based on a user-provided weighting scheme.
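A minimal sketch of both utilities, still using the assumed `lora_model` handle; the adapter names and weights below are illustrative assumptions, while `combination_type="svd"` refers to the SVD-based combination this change adds:

```python
# Build a new adapter "combined" as a weighted combination of two
# previously loaded adapters (names and weights are made up).
lora_model.add_weighted_adapter(
    adapters=["adapter_a", "adapter_b"],
    weights=[0.7, 0.3],
    adapter_name="combined",
    combination_type="svd",
)

# Remove an adapter that is no longer needed.
lora_model.delete_adapter("adapter_a")
```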

## Common LoRA parameters in PEFT

As with other methods supported by PEFT, to fine-tune a model using LoRA, you need to:
1 change: 1 addition & 0 deletions src/peft/mapping.py
@@ -127,6 +127,7 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name
"""
model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config
peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
peft_config, PromptLearningConfig
):
13 changes: 11 additions & 2 deletions src/peft/peft_model.py
@@ -1068,7 +1068,11 @@ def forward(
if peft_config.peft_type == PeftType.PREFIX_TUNING:
past_key_values = self.get_prompt(batch_size)
return self.base_model(
input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
decoder_inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
**kwargs,
)
elif peft_config.peft_type in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]:
if inputs_embeds is None:
@@ -1085,7 +1089,12 @@ def forward(
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)

return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
return self.base_model(
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_inputs_embeds=decoder_inputs_embeds,
**kwargs,
)
else:
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)