Skip to content

Commit

Permalink
FIX Multiple adapters and modules_to_save (#1615)
Browse files Browse the repository at this point in the history
Previously, we had the bug that if we had multiple adapters, some with
modules_to_save and others without, when trying to switch to an adapter
without modules_to_save, the ModulesToSaveWrapper would raise an error
because it cannot find that adapter. Now, when it detects this, it is
just disabled (so it uses the original weight).

Moreover, we had the issue that when we were using classes such as
PeftModelForSequenceClassification, we implicitly added the classifier
layers to model.modules_to_save. However, this would only add a new
ModulesToSaveWrapper instance for the first adapter being initialized.
When initializing a 2nd adapter via add_adapter, this information was
ignored. To fix this, I now update the peft_config.modules_to_save to
explicitly add the classifier layers. This is a departure from how this
worked previously, but I'm couldn't find a better way to ensure that
this bug was fixed.

Finally, there was a bug in add_weighted_adapters when we were merging
multiple adapters with modules_to_save. Previously, when we called
model.add_weighted_adapter, the LoRA weights were merged and a new
ModulesToSaveWrapper was added for the new adapter based on the first
LoraConfig of the two adapters. This ModulesToSaveWrapper is just a copy
of the original weights. Thus, when we switch to the newly merged
adapter, we just use the original weights for modules_to_save. This
doesn't make a lot of sense and is probably surprising for the user.
Now, we raise an error when we detect this to alert the user to this
fact.

Note that when only one of the adapters to be merged has a
modules_to_save, this does not raise an error, instead that module is
being used.
  • Loading branch information
BenjaminBossan authored Apr 9, 2024
1 parent e07095a commit 0d283ae
Show file tree
Hide file tree
Showing 4 changed files with 344 additions and 61 deletions.
127 changes: 116 additions & 11 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
self.modules_to_save = set(peft_config.modules_to_save)
else:
self.modules_to_save.update(peft_config.modules_to_save)
_set_trainable(self, adapter_name)
_set_trainable(self, adapter_name) # this may add a new ModulesToSaveWrapper

@classmethod
def _split_kwargs(cls, kwargs: dict[str, Any]):
Expand All @@ -714,7 +714,7 @@ def _split_kwargs(cls, kwargs: dict[str, Any]):

return hf_hub_download_kwargs, other_kwargs

def _update_offload(self, offload_index: dict[dict[str:str]], adapters_weights: dict[str : torch.tensor]):
def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_weights: dict[str, torch.tensor]):
"""
Update the offload_index and safetensors files for loading and mergine PeftModels with disk-offloaded modules.
Expand Down Expand Up @@ -1023,19 +1023,54 @@ class PeftModelForSequenceClassification(PeftModel):

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)

classifier_module_names = ["classifier", "score"]
if self.modules_to_save is None:
self.modules_to_save = {"classifier", "score"}
self.modules_to_save = set(classifier_module_names)
else:
self.modules_to_save.update({"classifier", "score"})
self.modules_to_save.update(classifier_module_names)

if hasattr(peft_config, "modules_to_save"):
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name)

def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
"""
Add an adapter to the model based on the passed configuration.
This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].
The name for the new adapter should be unique.
The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.
Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
"""
# ensure that additional adapters also add the classifier layer to modules_to_save
if hasattr(peft_config, "modules_to_save"):
classifier_module_names = ["classifier", "score"]
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

return super().add_adapter(adapter_name, peft_config)

def forward(
self,
input_ids=None,
Expand Down Expand Up @@ -1675,19 +1710,54 @@ class PeftModelForTokenClassification(PeftModel):

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)

classifier_module_names = ["classifier", "score"]
if self.modules_to_save is None:
self.modules_to_save = {"classifier", "score"}
self.modules_to_save = set(classifier_module_names)
else:
self.modules_to_save.update({"classifier", "score"})
self.modules_to_save.update(classifier_module_names)

if hasattr(peft_config, "modules_to_save"):
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name)

def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
"""
Add an adapter to the model based on the passed configuration.
This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].
The name for the new adapter should be unique.
The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.
Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
"""
# ensure that additional adapters also add the classifier layer to modules_to_save
if hasattr(peft_config, "modules_to_save"):
classifier_module_names = ["classifier", "score"]
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

return super().add_adapter(adapter_name, peft_config)

def forward(
self,
input_ids=None,
Expand Down Expand Up @@ -1850,19 +1920,54 @@ class PeftModelForQuestionAnswering(PeftModel):

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)

qa_module_names = ["qa_outputs"]
if self.modules_to_save is None:
self.modules_to_save = {"qa_outputs"}
self.modules_to_save = set(qa_module_names)
else:
self.modules_to_save.update({"qa_outputs"})
self.modules_to_save.update(qa_module_names)

if hasattr(peft_config, "modules_to_save"):
if peft_config.modules_to_save is None:
peft_config.modules_to_save = qa_module_names[:]
else:
peft_config.modules_to_save.extend(qa_module_names)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name)

def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
"""
Add an adapter to the model based on the passed configuration.
This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].
The name for the new adapter should be unique.
The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.
Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
"""
# ensure that additional adapters also add the classifier layer to modules_to_save
if hasattr(peft_config, "modules_to_save"):
qa_module_names = ["qa_outputs"]
if peft_config.modules_to_save is None:
peft_config.modules_to_save = qa_module_names[:]
else:
peft_config.modules_to_save.extend(qa_module_names)

return super().add_adapter(adapter_name, peft_config)

def forward(
self,
input_ids=None,
Expand Down
129 changes: 82 additions & 47 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,17 +449,85 @@ def _unload_and_optionally_merge(

return self.model

def _check_add_weighted_adapter(
self, adapters: list[str], combination_type: str, svd_rank: int | None
) -> tuple[str, int, str]:
"""
Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
model.
"""
for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")

# If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
# modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
# have modules for the adapters to be merged.
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
problematic_wrappers = [
wrapper
for wrapper in modules_to_save_wrappers
if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1
]
if problematic_wrappers:
raise ValueError(
"Cannot add weighted adapters if they target the same module with modules_to_save, but found "
f"{len(problematic_wrappers)} such instance(s)."
)

# if there is only one adapter, we can only use linear merging
combination_type = "linear" if len(adapters) == 1 else combination_type

adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
# all adapters ranks should be same, new rank is just this value
if len(set(adapters_ranks)) != 1:
raise ValueError(
"All adapters must have the same r value when using combination_type linear, ties, dare_ties or "
"dare_linear."
)
new_rank = adapters_ranks[0]
elif combination_type == "cat":
# adapters ranks may be different, new rank is sum of all ranks
# be careful, because output adapter rank may be really big if mixing a lot of adapters
new_rank = sum(adapters_ranks)
elif combination_type.endswith("svd"):
# new rank is the max of all ranks of the adapters if not provided
new_rank = svd_rank or max(adapters_ranks)
else:
raise ValueError(f"Invalid combination_type: {combination_type}")

target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
if not target_module_types:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
raise ValueError(
"all adapter configs should follow the same target modules type. "
"Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
)

if target_module_types[0] == str:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
elif target_module_types[0] == set:
new_target_modules = reduce(
operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
)
else:
raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")

return combination_type, new_rank, new_target_modules

def add_weighted_adapter(
self,
adapters,
weights,
adapter_name,
combination_type="svd",
svd_rank=None,
svd_clamp=None,
svd_full_matrices=True,
svd_driver=None,
density=None,
adapters: list[str],
weights: list[float],
adapter_name: str,
combination_type: str = "svd",
svd_rank: int | None = None,
svd_clamp: int | None = None,
svd_full_matrices: bool = True,
svd_driver: str | None = None,
density: float | None = None,
majority_sign_method: Literal["total", "frequency"] = "total",
) -> None:
"""
Expand Down Expand Up @@ -508,44 +576,11 @@ def add_weighted_adapter(
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")

# if there is only one adapter, we can only use linear merging
combination_type = "linear" if len(adapters) == 1 else combination_type

adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
# all adapters ranks should be same, new rank is just this value
if len(set(adapters_ranks)) != 1:
raise ValueError(
"All adapters must have the same r value when using combination_type linear, ties, dare_ties or dare_linear."
)
new_rank = adapters_ranks[0]
elif combination_type == "cat":
# adapters ranks may be different, new rank is sum of all ranks
# be careful, because output adapter rank may be really big if mixing a lot of adapters
new_rank = sum(adapters_ranks)
elif combination_type.endswith("svd"):
# new rank is the max of all ranks of the adapters if not provided
new_rank = svd_rank or max(adapters_ranks)
else:
raise ValueError(f"Invalid combination_type: {combination_type}")

target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
if not target_module_types:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
raise ValueError(
"all adapter configs should follow the same target modules type. "
"Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
)

if target_module_types[0] == str:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
elif target_module_types[0] == set:
new_target_modules = reduce(
operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
)
else:
raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")
combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
adapters=adapters,
combination_type=combination_type,
svd_rank=svd_rank,
)

self.peft_config[adapter_name] = replace(
self.peft_config[adapters[0]],
Expand Down
8 changes: 7 additions & 1 deletion src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,13 @@ def check_adapter_name(adapter_name):
if isinstance(module, ModulesToSaveWrapper):
# only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care
adapter_name = check_adapter_name(adapter_name)
module.set_adapter(adapter_name)

# if the adapter is found in this module, set it as the active adapter, else disable the adapters of this
# module
if adapter_name in module.modules_to_save:
module.set_adapter(adapter_name)
else:
module.enable_adapters(False)


def _prepare_prompt_learning_config(peft_config, model_config):
Expand Down
Loading

0 comments on commit 0d283ae

Please sign in to comment.