BOFT bug fix when saving #1994

Merged: 9 commits merged on Aug 7, 2024
8 changes: 4 additions & 4 deletions src/peft/tuners/boft/layer.py
@@ -520,7 +520,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

- self.base_layer.weight.data = orig_weight
+ self.base_layer.weight.data = orig_weight.contiguous()
else:
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
orig_weight = base_layer.weight.data.clone()
@@ -529,7 +529,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s

- self.base_layer.weight.data = orig_weight
+ self.base_layer.weight.data = orig_weight.contiguous()

self.merged_adapters.append(active_adapter)

@@ -817,7 +817,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)

- self.base_layer.weight.data = orig_weight
+ self.base_layer.weight.data = orig_weight.contiguous()
else:
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

@@ -831,7 +831,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)

- self.base_layer.weight.data = orig_weight
+ self.base_layer.weight.data = orig_weight.contiguous()

self.merged_adapters.append(active_adapter)

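The `.contiguous()` calls above address the actual save failure: the merged weight is assembled with `torch.transpose`, which returns a non-contiguous view, and safetensors refuses to serialize non-contiguous tensors. A minimal sketch of the failure mode, independent of PEFT (the file name and shapes are made up for illustration):

```python
import torch
from safetensors.torch import save_file

# torch.transpose returns a view over the original storage,
# so the result is generally not contiguous in memory.
weight = torch.randn(4, 8)
transposed = torch.transpose(weight, 0, 1)
assert not transposed.is_contiguous()

# safetensors rejects non-contiguous tensors when saving.
try:
    save_file({"weight": transposed}, "model.safetensors")
except Exception as err:
    print(f"saving the non-contiguous tensor failed: {err}")

# .contiguous() copies the data into a densely packed layout that serializes fine.
save_file({"weight": transposed.contiguous()}, "model.safetensors")
```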
33 changes: 23 additions & 10 deletions src/peft/tuners/boft/model.py
@@ -24,7 +24,12 @@
from torch import nn
from tqdm import tqdm

- from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
+ from peft.tuners.tuners_utils import (
+     BaseTuner,
+     BaseTunerLayer,
+     check_target_module_exists,
+     onload_layer,
+ )
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
@@ -265,22 +270,30 @@ def _unload_and_optionally_merge(
safe_merge: bool = False,
adapter_names: Optional[List[str]] = None,
):
self._unloading_checks(adapter_names)
if merge:
self._check_merge_allowed()

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue

- if hasattr(target, "base_layer"):
-     if merge:
-         target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
-     self._replace_module(parent, target_name, target.get_base_layer(), target)
- elif isinstance(target, ModulesToSaveWrapper):
-     # save any additional trainable modules part of `modules_to_save`
-     setattr(parent, target_name, target.modules_to_save[target.active_adapter])
+ with onload_layer(target):
+     if hasattr(target, "base_layer"):
+         if merge:
+             target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+         self._replace_module(parent, target_name, target.get_base_layer(), target)
+     elif isinstance(target, ModulesToSaveWrapper):
+         # save any additional trainable modules part of `modules_to_save`
+         new_module = target.modules_to_save[target.active_adapter]
+         if hasattr(new_module, "base_layer"):
+             # check if the module is itself a tuner layer
+             if merge:
+                 new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+             new_module = new_module.get_base_layer()
+         setattr(parent, target_name, new_module)

return self.model

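For context, `_unload_and_optionally_merge` is the helper behind `merge_and_unload()` on a BOFT-wrapped model. A rough usage sketch with a made-up toy model (the module names, block size, and final assertion are illustrative, not taken from this PR):

```python
import torch
from torch import nn
from peft import BOFTConfig, get_peft_model


class MLP(nn.Module):
    """Toy model; the layer names are invented for this sketch."""

    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(16, 16)
        self.lin1 = nn.Linear(16, 2)

    def forward(self, x):
        return self.lin1(torch.relu(self.lin0(x)))


config = BOFTConfig(boft_block_size=4, target_modules=["lin0"])
model = get_peft_model(MLP(), config)

# merge_and_unload() folds the adapter into the base weights via
# _unload_and_optionally_merge(); with the .contiguous() fix the
# resulting weight can be handed to safetensors directly.
merged = model.merge_and_unload()
assert merged.lin0.weight.is_contiguous()
```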
4 changes: 2 additions & 2 deletions src/peft/tuners/oft/layer.py
@@ -171,7 +171,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

- base_layer.weight.data = new_weights
+ base_layer.weight.data = new_weights.contiguous()
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -215,7 +215,7 @@ def unmerge(self) -> None:
base_layer.kernel_size[1],
]
)
- base_layer.weight.data = orig_weights
+ base_layer.weight.data = orig_weights.contiguous()

def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
rank = self.r[adapter_name]
4 changes: 4 additions & 0 deletions tests/testing_common.py
@@ -763,6 +763,10 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
# check that the logits are the same after unloading
assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

+ # serializing with safetensors works
Member (review comment), suggested change:
- # serializing with safetensors works
+ # Ensure that serializing with safetensors works, there was an error when weights were not contiguous

+ from safetensors.torch import save_file
Member (review comment): Let's move this import to the top of the file.

+ save_file(model_unloaded.state_dict(), os.path.join(tmp_dirname, "model.safetensors"))

def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
# Test for mixing different adapters in a single batch by passing the adapter_names argument
if config_cls not in (LoraConfig,):