Add progressbar unload/merge #753

Merged · 5 commits · Aug 1, 2023
Changes from all commits
1 change: 1 addition & 0 deletions setup.py
@@ -44,6 +44,7 @@
"pyyaml",
"torch>=1.13.0",
"transformers",
"tqdm",
"accelerate",
"safetensors",
],
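For context, the new `tqdm` runtime requirement only backs the optional progress bar introduced in this PR; a minimal sketch of what the library provides (the loop body is a stand-in, not PEFT code):

```py
from tqdm import tqdm
import time

# tqdm wraps any iterable and renders a progress bar; with disable=True it
# becomes a no-op wrapper, which is how the PR keeps the bar opt-in.
for step in tqdm(range(5), desc="Demo", disable=False):
    time.sleep(0.1)  # stand-in for real per-step work
```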
16 changes: 11 additions & 5 deletions src/peft/tuners/lora.py
@@ -22,6 +22,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+ from tqdm import tqdm
from transformers.pytorch_utils import Conv1D

from ..import_utils import is_bnb_4bit_available, is_bnb_available
@@ -53,7 +54,8 @@ class LoraConfig(PeftConfig):
lora_alpha (`int`): The alpha parameter for Lora scaling.
lora_dropout (`float`): The dropout probability for Lora layers.
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
- For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.:
+ For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set
+ to `True`.
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the
corresponding biases will be updated during training. Be aware that this means that, even when disabling
the adapters, the model will not produce the same output as the base model would have without adaptation.
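Since the reflowed text documents the `fan_in_fan_out` flag, a hedged example of a config that sets it for a GPT-2-style `Conv1D` layer may help (the hyperparameter values and target module name are illustrative, not taken from this PR):

```py
from peft import LoraConfig

# Hypothetical config for a GPT-2-style model whose attention projection is a
# transformers Conv1D storing weights as (fan_in, fan_out).
config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["c_attn"],  # illustrative Conv1D module name
    fan_in_fan_out=True,
    bias="none",
)
```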
@@ -458,12 +460,13 @@ def _prepare_lora_config(peft_config, model_config):
peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
return peft_config

- def _unload_and_optionally_merge(self, merge=True):
+ def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
- for key in key_list:
+ desc = "Unloading " + ("and merging " if merge else "") + "model"
+ for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
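The heart of the change is `disable=not progressbar`: the bar renders only when the caller opts in, and `desc` says whether a merge is happening. A self-contained sketch of the same pattern, using a hypothetical helper name and made-up module keys:

```py
from tqdm import tqdm

def unload_modules(keys, merge=True, progressbar=False):
    # Mirrors the loop in the diff: hidden by default, labelled by merge mode.
    desc = "Unloading " + ("and merging " if merge else "") + "model"
    for key in tqdm(keys, disable=not progressbar, desc=desc):
        ...  # per-module unload/merge work would go here

# Hypothetical module keys, just to exercise the progress bar.
unload_modules([f"model.layers.{i}" for i in range(3)], progressbar=True)
```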
@@ -621,11 +624,14 @@ def delete_adapter(self, adapter_name):
)
target.active_adapter = resetting_active_adapter

- def merge_and_unload(self):
+ def merge_and_unload(self, progressbar: bool = False):
r"""
This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
as a standalone model.

+ Args:
+     progressbar (bool): whether to show a progressbar indicating the unload and merge process

Example:

```py
@@ -638,7 +644,7 @@ def merge_and_unload(self):
>>> merged_model = model.merge_and_unload()
```
"""
- return self._unload_and_optionally_merge()
+ return self._unload_and_optionally_merge(progressbar=progressbar)

def unload(self):
"""
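With the keyword threaded through `_unload_and_optionally_merge`, callers can now opt into the bar. A usage sketch, assuming a causal LM base model and a local adapter path (both placeholders, not from this PR):

```py
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder checkpoint and adapter names, for illustration only.
base = AutoModelForCausalLM.from_pretrained("base-model-name")
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# Show a tqdm bar while the LoRA weights are folded into the base model.
merged_model = model.merge_and_unload(progressbar=True)
```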