[utils] add merge_lora utility function (#227)
* add merge_lora utility function

* forward contrib credits from original script

* some changes

* make style

* fix tests

* finally fix tests

* Update tests/test_peft_model.py

* adapt from suggestions

* adapt

* Update src/peft/tuners/lora.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* fix 8bit

* Update src/peft/tuners/lora.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

---------

Co-authored-by: edbeeching <edbeeching@users.noreply.github.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
3 people authored Mar 30, 2023
1 parent 542f247 commit 8f63f56
Showing 3 changed files with 148 additions and 39 deletions.
47 changes: 45 additions & 2 deletions src/peft/tuners/lora.py
@@ -82,6 +82,10 @@ class LoraConfig(PeftConfig):
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
init_lora_weights: bool = field(
default=True,
metadata={"help": "Whether to initialize the weights of the Lora layers."},
)

def __post_init__(self):
self.peft_type = PeftType.LORA
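
For context, a minimal usage sketch of the new flag (assuming the `LoraConfig`/`get_peft_model` API exercised in the tests below; the tiny test model id is only an example):

```python
# Sketch only: with init_lora_weights=False the LoRA matrices keep their random
# nn.Linear defaults instead of being reset (A re-initialized, B zeroed), so the
# adapter already changes the model's outputs before any training.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
config = LoraConfig(r=8, lora_alpha=16, init_lora_weights=False)  # new flag, defaults to True
model = get_peft_model(base, config)
```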
@@ -135,6 +139,7 @@ def _find_and_replace(self):
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
"merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
and not is_hf_device_map_available,
"init_lora_weights": self.peft_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
@@ -233,6 +238,37 @@ def enable_adapter_layers(self):
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)

def merge_and_unload(self):
r"""
This method merges the LoRA layers into the base model. This is needed if someone wants to use the base model
as a standalone model.
"""
if self.config.model_type == "gpt2":
raise ValueError("GPT2 models are not supported for merging LORA layers")

if getattr(self.model, "is_loaded_in_8bit", False):
raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
for key in key_list:
parent, target, target_name = self._get_submodules(key)
if isinstance(target, LoraLayer):
bias = target.bias is not None
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)

# manually merge if not merged
if not target.merged:
# merge weights per: https://arxiv.org/pdf/2106.09685.pdf / page 4
if target.r > 0:
target.weight.data += (
transpose(target.lora_B.weight @ target.lora_A.weight, target.fan_in_fan_out)
* target.scaling
).to(target.weight.dtype)
target.merged = True

self._replace_module(parent, target_name, new_module, target)
return self.model


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
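
To make the intent of `merge_and_unload` concrete, here is a hedged usage sketch (the base model id and adapter path are placeholders, not part of this commit). Per the code above, each adapted layer is folded as `W ← W + (lora_B @ lora_A) * scaling` (transposed for `fan_in_fan_out` layers) and then replaced by a plain `nn.Linear`:

```python
# Rough sketch: fold trained LoRA weights into the base model and save a
# standalone checkpoint that loads without peft installed.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")          # placeholder id
peft_model = PeftModel.from_pretrained(base, "my-user/my-lora-adapter")   # placeholder adapter

merged = peft_model.merge_and_unload()    # raises for gpt2 and 8-bit models, as above
merged.save_pretrained("./merged-model")  # plain transformers checkpoint from here on
```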
@@ -297,6 +333,8 @@ def __init__(
merge_weights: bool = True,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)

nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)

@@ -308,7 +346,8 @@ def __init__(
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if init_lora_weights:
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T

@@ -375,6 +414,8 @@ def __init__(
merge_weights: bool = True,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)

nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
if out_features % len(enable_lora) != 0:
@@ -398,7 +439,9 @@ def __init__(
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()

if init_lora_weights:
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
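
Why the merge tests below set `init_lora_weights=False`: with the default initialization, `reset_parameters` zeroes `lora_B`, so the LoRA delta `B @ A` is exactly zero and a freshly merged model would be indistinguishable from the base model. A toy illustration (not library code):

```python
import torch

r, fan_in, fan_out = 4, 16, 16
lora_A = torch.randn(r, fan_in)   # re-initialized (kaiming uniform) by reset_parameters
lora_B = torch.zeros(fan_out, r)  # zeroed by reset_parameters when init_lora_weights=True
scaling = 16 / r                  # lora_alpha / r

delta = (lora_B @ lora_A) * scaling
print(delta.abs().max())          # tensor(0.) -> merging would be a no-op
# With init_lora_weights=False, lora_B keeps a random init, so delta is non-zero
# and a test can tell the merged model apart from the plain base model.
```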

83 changes: 68 additions & 15 deletions tests/test_peft_model.py
@@ -30,17 +30,19 @@
from .testing_common import PeftTestConfigManager


# This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs
PEFT_DECODER_MODELS_TO_TEST = [
# ("HuggingFaceM4/tiny-random-LlamaForCausalLM", {}, {}, {}, {}), wait until the next `transformers` release
("hf-internal-testing/tiny-random-OPTForCausalLM", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-GPTNeoXForCausalLM", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-GPT2LMHeadModel", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-BloomForCausalLM", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-gpt_neo", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-GPTJForCausalLM", {}, {}, {}, {}),
"hf-internal-testing/tiny-random-OPTForCausalLM",
"hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"hf-internal-testing/tiny-random-GPT2LMHeadModel",
"hf-internal-testing/tiny-random-BloomForCausalLM",
"hf-internal-testing/tiny-random-gpt_neo",
"hf-internal-testing/tiny-random-GPTJForCausalLM",
]

FULL_GRID = {
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
}


class PeftTestMixin:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -54,10 +56,6 @@ class PeftModelTester(unittest.TestCase, PeftTestMixin):
We use parametrized.expand for debugging purposes to test each model individually.
"""

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
self._test_model_attr(model_id, config_cls, config_kwargs)

def _test_model_attr(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id)
config = config_cls(
@@ -70,6 +68,10 @@ def _test_model_attr(self, model_id, config_cls, config_kwargs):
self.assertTrue(hasattr(model, "from_pretrained"))
self.assertTrue(hasattr(model, "push_to_hub"))

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
self._test_model_attr(model_id, config_cls, config_kwargs)

def _test_prepare_for_training(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
config = config_cls(
@@ -111,7 +113,7 @@ def make_inputs_require_grad(module, input, output):

self.assertTrue(dummy_output.requires_grad)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
self._test_prepare_for_training(model_id, config_cls, config_kwargs)

@@ -157,10 +159,61 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs)

def _test_merge_layers(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(self.torch_device)

if config.peft_type != "LORA":
with self.assertRaises(AttributeError):
model = model.merge_and_unload()
elif model.config.model_type == "gpt2":
with self.assertRaises(ValueError):
model = model.merge_and_unload()
else:
dummy_input = torch.LongTensor([[1, 2, 3, 2, 1]]).to(self.torch_device)
model.eval()
logits_lora = model(dummy_input)[0]

model = model.merge_and_unload()

logits_merged = model(dummy_input)[0]

transformers_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)

logits_transformers = transformers_model(dummy_input)[0]

self.assertTrue(torch.allclose(logits_lora, logits_merged, atol=1e-3, rtol=1e-3))
self.assertFalse(torch.allclose(logits_merged, logits_transformers, atol=1e-3, rtol=1e-3))

with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)

model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmp_dirname).to(self.torch_device)

logits_merged_from_pretrained = model_from_pretrained(dummy_input)[0]

self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=1e-3, rtol=1e-3))

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False], "merge_weights": [False, True]},
},
)
)
def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
self._test_merge_layers(model_id, config_cls, config_kwargs)
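
Beyond comparing logits, the merge can also be checked in closed form against the update applied by `merge_and_unload`. A rough sketch, not part of the test suite: the helper name and module path below are illustrative and architecture-dependent, and the `fan_in_fan_out` transpose is ignored for brevity.

```python
import torch

def expected_merged_weight(lora_layer):
    # Illustrative helper: closed-form merged weight for a LoRA Linear layer,
    # W + (lora_B @ lora_A) * scaling. Call it before the layer has been merged.
    delta = (lora_layer.lora_B.weight @ lora_layer.lora_A.weight) * lora_layer.scaling
    return lora_layer.weight.data + delta.to(lora_layer.weight.dtype)

# usage sketch (module path depends on the base architecture, OPT shown here):
# layer = peft_model.base_model.model.model.decoder.layers[0].self_attn.v_proj
# want = expected_merged_weight(layer)
# merged = peft_model.merge_and_unload()
# assert torch.allclose(merged.model.decoder.layers[0].self_attn.v_proj.weight, want)
```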

def _test_generate(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id)
config = config_cls(
@@ -180,6 +233,6 @@ def _test_generate(self, model_id, config_cls, config_kwargs):
# check if `generate` raises an error if no positional arguments are passed
_ = model.generate(input_ids, attention_mask=attention_mask)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate(model_id, config_cls, config_kwargs)
57 changes: 35 additions & 22 deletions tests/testing_common.py
@@ -71,34 +71,47 @@ def __getitem__(self, key, *args, **kwargs):

return super().__getitem__(key, *args, **kwargs)

def get_grid_parameters(self, model_list):
def get_grid_parameters(self, grid_parameters, filter_params_func=None):
r"""
Returns a list of all possible combinations of the parameters in the config classes.
Args:
grid_parameters (`dict`):
A dictionary containing the parameters to be tested. There should be at least the key "model_ids" which
contains a list of model ids to be tested. The other keys should be the name of the config class
post-fixed with "_kwargs" and the value should be a dictionary containing the parameters to be tested
for that config class.
filter_params_func (`callable`, `optional`):
A function that takes a list of tuples and returns a list of tuples. This function is used to filter
out tests that, for example, need to be skipped.
Returns:
generated_tests (`list`):
A list of tuples containing the name of the test, the model id, the config class and the config class
kwargs.
"""
grid_parameters = []
for model_tuple in model_list:
model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs = model_tuple
generated_tests = []
model_list = grid_parameters["model_ids"]

for model_id in model_list:
for key, value in self.items():
peft_method = value[1].copy()
if key == "lora":
# update value[1] if necessary
if lora_kwargs is not None:
peft_method.update(lora_kwargs)
elif key == "prefix_tuning":
# update value[1] if necessary
if prefix_tuning_kwargs is not None:
peft_method.update(prefix_tuning_kwargs)
elif key == "prompt_encoder":
# update value[1] if necessary
if prompt_encoder_kwargs is not None:
peft_method.update(prompt_encoder_kwargs)
if "{}_kwargs".format(key) in grid_parameters:
peft_configs = []
current_peft_config = value[1].copy()
for current_key, current_value in grid_parameters[f"{key}_kwargs"].items():
for kwarg in current_value:
current_peft_config.update({current_key: kwarg})
peft_configs.append(current_peft_config)
else:
# update value[1] if necessary
if prompt_tuning_kwargs is not None:
peft_method.update(prompt_tuning_kwargs)
grid_parameters.append((f"test_{model_id}_{key}", model_id, value[0], peft_method))
peft_configs = [value[1].copy()]

for peft_config in peft_configs:
generated_tests.append((f"test_{model_id}_{key}", model_id, value[0], peft_config))

if filter_params_func is not None:
generated_tests = filter_params_func(generated_tests)

return grid_parameters
return generated_tests


PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING)
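
A sketch of how the refactored grid API is consumed, mirroring the merge test above (output formatting is approximate):

```python
# Keys other than "model_ids" must be named "<peft method>_kwargs" (e.g. "lora_kwargs");
# each maps a config field to the list of settings to sweep for that PEFT method.
grid = {
    "model_ids": ["hf-internal-testing/tiny-random-OPTForCausalLM"],
    "lora_kwargs": {"init_lora_weights": [False]},
}
# PeftTestConfigManager is the ClassInstantier instance created just above.
for test_name, model_id, config_cls, config_kwargs in PeftTestConfigManager.get_grid_parameters(grid):
    print(test_name, config_cls.__name__, config_kwargs)
```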
