diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py
index 97a1baaa58..604f83fcae 100644
--- a/src/peft/tuners/boft/layer.py
+++ b/src/peft/tuners/boft/layer.py
@@ -520,7 +520,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )

-                    self.base_layer.weight.data = orig_weight
+                    self.base_layer.weight.data = orig_weight.contiguous()
                 else:
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
                     orig_weight = base_layer.weight.data.clone()
@@ -529,7 +529,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s

-                    self.base_layer.weight.data = orig_weight
+                    self.base_layer.weight.data = orig_weight.contiguous()

                 self.merged_adapters.append(active_adapter)

@@ -817,7 +817,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                         self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
                     )

-                    self.base_layer.weight.data = orig_weight
+                    self.base_layer.weight.data = orig_weight.contiguous()
                 else:
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

@@ -831,7 +831,7 @@
                         self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
                     )

-                    self.base_layer.weight.data = orig_weight
+                    self.base_layer.weight.data = orig_weight.contiguous()

                 self.merged_adapters.append(active_adapter)

diff --git a/src/peft/tuners/boft/model.py b/src/peft/tuners/boft/model.py
index 0cb3a92915..11bd4c3ad2 100644
--- a/src/peft/tuners/boft/model.py
+++ b/src/peft/tuners/boft/model.py
@@ -24,7 +24,12 @@
 from torch import nn
 from tqdm import tqdm

-from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
+from peft.tuners.tuners_utils import (
+    BaseTuner,
+    BaseTunerLayer,
+    check_target_module_exists,
+    onload_layer,
+)
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
     ModulesToSaveWrapper,
@@ -265,7 +270,9 @@ def _unload_and_optionally_merge(
         safe_merge: bool = False,
         adapter_names: Optional[List[str]] = None,
     ):
-        self._unloading_checks(adapter_names)
+        if merge:
+            self._check_merge_allowed()
+
         key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
         desc = "Unloading " + ("and merging " if merge else "") + "model"
         for key in tqdm(key_list, disable=not progressbar, desc=desc):
@@ -273,14 +280,20 @@
                 parent, target, target_name = _get_submodules(self.model, key)
             except AttributeError:
                 continue
-
-            if hasattr(target, "base_layer"):
-                if merge:
-                    target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
-                self._replace_module(parent, target_name, target.get_base_layer(), target)
-            elif isinstance(target, ModulesToSaveWrapper):
-                # save any additional trainable modules part of `modules_to_save`
-                setattr(parent, target_name, target.modules_to_save[target.active_adapter])
+            with onload_layer(target):
+                if hasattr(target, "base_layer"):
+                    if merge:
+                        target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+                    self._replace_module(parent, target_name, target.get_base_layer(), target)
+                elif isinstance(target, ModulesToSaveWrapper):
+                    # save any additional trainable modules part of `modules_to_save`
+                    new_module = target.modules_to_save[target.active_adapter]
+                    if hasattr(new_module, "base_layer"):
+                        # check if the module is itself a tuner layer
+                        if merge:
+                            new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+                        new_module = new_module.get_base_layer()
+                    setattr(parent, target_name, new_module)

         return self.model

diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py
index 8f973bda8b..965f2e83ff 100644
--- a/src/peft/tuners/oft/layer.py
+++ b/src/peft/tuners/oft/layer.py
@@ -171,7 +171,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                         f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                     )

-                base_layer.weight.data = new_weights
+                base_layer.weight.data = new_weights.contiguous()
                 self.merged_adapters.append(active_adapter)

     def unmerge(self) -> None:
@@ -215,7 +215,7 @@
                             base_layer.kernel_size[1],
                         ]
                     )
-                base_layer.weight.data = orig_weights
+                base_layer.weight.data = orig_weights.contiguous()

     def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
         rank = self.r[adapter_name]
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 9168b54b5a..da58337bc2 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -26,6 +26,7 @@
 import yaml
 from diffusers import StableDiffusionPipeline
 from packaging import version
+from safetensors.torch import save_file

 from peft import (
     AdaLoraConfig,
@@ -763,6 +764,14 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
         # check that the logits are the same after unloading
         assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

+        # Ensure that serializing with safetensors works, there was an error when weights were not contiguous
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            # serializing with torch.save works
+            torch.save(model_unloaded.state_dict(), os.path.join(tmp_dirname, "model.bin"))
+
+            # serializing with safetensors works
+            save_file(model_unloaded.state_dict(), os.path.join(tmp_dirname, "model.safetensors"))
+
     def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
         # Test for mixing different adapters in a single batch by passing the adapter_names argument
         if config_cls not in (LoraConfig,):
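For context on the `.contiguous()` calls above: `safetensors` refuses to serialize non-contiguous tensors, while `torch.save` accepts them, and the transposes performed during the BOFT/OFT merge leave the base weight as a non-contiguous view. The snippet below is a minimal sketch of that failure mode, not part of the patch; the tensor name and file paths are illustrative only.

```python
import os
import tempfile

import torch
from safetensors.torch import save_file

# A transposed 2D tensor is a view and therefore not contiguous, which is
# roughly the state the merged BOFT/OFT weights were in before this patch.
weight = torch.randn(4, 8)
state_dict = {"weight": weight.t()}

with tempfile.TemporaryDirectory() as tmp_dirname:
    # torch.save accepts non-contiguous tensors without complaint
    torch.save(state_dict, os.path.join(tmp_dirname, "model.bin"))

    try:
        # safetensors rejects non-contiguous tensors
        save_file(state_dict, os.path.join(tmp_dirname, "model.safetensors"))
    except ValueError as exc:
        print(f"safetensors refused to serialize: {exc}")

    # calling .contiguous() before saving avoids the error
    save_file({"weight": weight.t().contiguous()}, os.path.join(tmp_dirname, "model.safetensors"))
```

The new check at the end of `_test_safe_merge` in `tests/testing_common.py` exercises the same code path on the unloaded model's `state_dict()`.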