FIX Multiple adapters and modules_to_save #1615

Merged
127 changes: 116 additions & 11 deletions src/peft/peft_model.py
@@ -698,7 +698,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
self.modules_to_save = set(peft_config.modules_to_save)
else:
self.modules_to_save.update(peft_config.modules_to_save)
_set_trainable(self, adapter_name)
_set_trainable(self, adapter_name) # this may add a new ModulesToSaveWrapper

@classmethod
def _split_kwargs(cls, kwargs: dict[str, Any]):
@@ -714,7 +714,7 @@ def _split_kwargs(cls, kwargs: dict[str, Any]):

return hf_hub_download_kwargs, other_kwargs

def _update_offload(self, offload_index: dict[dict[str:str]], adapters_weights: dict[str : torch.tensor]):
def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_weights: dict[str, torch.tensor]):
"""
Update the offload_index and safetensors files for loading and merging PeftModels with disk-offloaded modules.

@@ -1020,19 +1020,54 @@ class PeftModelForSequenceClassification(PeftModel):

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)

classifier_module_names = ["classifier", "score"]
if self.modules_to_save is None:
self.modules_to_save = {"classifier", "score"}
self.modules_to_save = set(classifier_module_names)
else:
self.modules_to_save.update({"classifier", "score"})
self.modules_to_save.update(classifier_module_names)

if hasattr(peft_config, "modules_to_save"):
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name)

def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
"""
Add an adapter to the model based on the passed configuration.

This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].

The name for the new adapter should be unique.

The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.

Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
"""
# ensure that additional adapters also add the classifier layer to modules_to_save
if hasattr(peft_config, "modules_to_save"):
classifier_module_names = ["classifier", "score"]
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

return super().add_adapter(adapter_name, peft_config)

def forward(
self,
input_ids=None,
@@ -1672,19 +1707,54 @@ class PeftModelForTokenClassification(PeftModel):

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)

classifier_module_names = ["classifier", "score"]
if self.modules_to_save is None:
self.modules_to_save = {"classifier", "score"}
self.modules_to_save = set(classifier_module_names)
else:
self.modules_to_save.update({"classifier", "score"})
self.modules_to_save.update(classifier_module_names)

if hasattr(peft_config, "modules_to_save"):
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name)

def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
"""
Add an adapter to the model based on the passed configuration.

This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].

The name for the new adapter should be unique.

The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.

Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
"""
# ensure that additional adapters also add the classifier layer to modules_to_save
if hasattr(peft_config, "modules_to_save"):
classifier_module_names = ["classifier", "score"]
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

return super().add_adapter(adapter_name, peft_config)

def forward(
self,
input_ids=None,
@@ -1847,19 +1917,54 @@ class PeftModelForQuestionAnswering(PeftModel):

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)

qa_module_names = ["qa_outputs"]
if self.modules_to_save is None:
self.modules_to_save = {"qa_outputs"}
self.modules_to_save = set(qa_module_names)
else:
self.modules_to_save.update({"qa_outputs"})
self.modules_to_save.update(qa_module_names)

if hasattr(peft_config, "modules_to_save"):
if peft_config.modules_to_save is None:
peft_config.modules_to_save = qa_module_names[:]
else:
peft_config.modules_to_save.extend(qa_module_names)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name)

def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
"""
Add an adapter to the model based on the passed configuration.

This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].

The name for the new adapter should be unique.

The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.

Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
"""
# ensure that additional adapters also add the classifier layer to modules_to_save
if hasattr(peft_config, "modules_to_save"):
qa_module_names = ["qa_outputs"]
if peft_config.modules_to_save is None:
peft_config.modules_to_save = qa_module_names[:]
else:
peft_config.modules_to_save.extend(qa_module_names)

return super().add_adapter(adapter_name, peft_config)

def forward(
self,
input_ids=None,
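For context, a minimal usage sketch of what the `peft_model.py` changes enable on the sequence-classification wrapper. The checkpoint, adapter names, and LoRA hyperparameters below are illustrative only and not taken from this PR; the point is that `add_adapter` now also records the classifier head in the second adapter's `modules_to_save`, so both heads get wrapped and saved.

```python
# Illustrative sketch only -- checkpoint, adapter names and LoRA settings are
# made up for the example, they are not part of this PR.
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

config_a = LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16)
model = get_peft_model(base, config_a, adapter_name="adapter_a")

# With this fix, add_adapter extends the second config's modules_to_save with
# the classifier module names, so the head is wrapped for "adapter_b" as well.
config_b = LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16)
model.add_adapter("adapter_b", config_b)

print(model.peft_config["adapter_b"].modules_to_save)  # expected to contain "classifier"/"score"
```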
129 changes: 82 additions & 47 deletions src/peft/tuners/lora/model.py
@@ -449,17 +449,85 @@ def _unload_and_optionally_merge(

return self.model

def _check_add_weighted_adapter(
self, adapters: list[str], combination_type: str, svd_rank: int | None
) -> tuple[str, int, str]:
"""
Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
model.
"""
for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")

# If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
# modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
# have modules for the adapters to be merged.
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
problematic_wrappers = [
wrapper
for wrapper in modules_to_save_wrappers
if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1
]
if problematic_wrappers:
raise ValueError(
"Cannot add weighted adapters if they target the same module with modules_to_save, but found "
f"{len(problematic_wrappers)} such instance(s)."
)

# if there is only one adapter, we can only use linear merging
combination_type = "linear" if len(adapters) == 1 else combination_type

adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
# all adapters ranks should be same, new rank is just this value
if len(set(adapters_ranks)) != 1:
raise ValueError(
"All adapters must have the same r value when using combination_type linear, ties, dare_ties or "
"dare_linear."
)
new_rank = adapters_ranks[0]
elif combination_type == "cat":
# adapters ranks may be different, new rank is sum of all ranks
# be careful, because output adapter rank may be really big if mixing a lot of adapters
new_rank = sum(adapters_ranks)
elif combination_type.endswith("svd"):
# new rank is the max of all ranks of the adapters if not provided
new_rank = svd_rank or max(adapters_ranks)
else:
raise ValueError(f"Invalid combination_type: {combination_type}")

target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
if not target_module_types:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
raise ValueError(
"all adapter configs should follow the same target modules type. "
"Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
)

if target_module_types[0] == str:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
elif target_module_types[0] == set:
new_target_modules = reduce(
operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
)
else:
raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")

return combination_type, new_rank, new_target_modules

def add_weighted_adapter(
self,
adapters,
weights,
adapter_name,
combination_type="svd",
svd_rank=None,
svd_clamp=None,
svd_full_matrices=True,
svd_driver=None,
density=None,
adapters: list[str],
weights: list[float],
adapter_name: str,
combination_type: str = "svd",
svd_rank: int | None = None,
svd_clamp: int | None = None,
svd_full_matrices: bool = True,
svd_driver: str | None = None,
density: float | None = None,
majority_sign_method: Literal["total", "frequency"] = "total",
) -> None:
"""
@@ -508,44 +576,11 @@ def add_weighted_adapter(
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")

# if there is only one adapter, we can only use linear merging
combination_type = "linear" if len(adapters) == 1 else combination_type

adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
# all adapters ranks should be same, new rank is just this value
if len(set(adapters_ranks)) != 1:
raise ValueError(
"All adapters must have the same r value when using combination_type linear, ties, dare_ties or dare_linear."
)
new_rank = adapters_ranks[0]
elif combination_type == "cat":
# adapters ranks may be different, new rank is sum of all ranks
# be careful, because output adapter rank may be really big if mixing a lot of adapters
new_rank = sum(adapters_ranks)
elif combination_type.endswith("svd"):
# new rank is the max of all ranks of the adapters if not provided
new_rank = svd_rank or max(adapters_ranks)
else:
raise ValueError(f"Invalid combination_type: {combination_type}")

target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
if not target_module_types:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
raise ValueError(
"all adapter configs should follow the same target modules type. "
"Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
)

if target_module_types[0] == str:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
elif target_module_types[0] == set:
new_target_modules = reduce(
operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
)
else:
raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")
combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
adapters=adapters,
combination_type=combination_type,
svd_rank=svd_rank,
)

self.peft_config[adapter_name] = replace(
self.peft_config[adapters[0]],
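The validation that used to live inline in `add_weighted_adapter` is now factored into `_check_add_weighted_adapter`, which additionally rejects merging adapters that target the same module through `modules_to_save`. As a rough standalone sketch of one part of that helper, this is how the merged `target_modules` is derived; the function name and structure below are illustrative, not the actual PEFT API, and the real method also checks adapter existence, ranks, and the `combination_type`.

```python
# Simplified, standalone re-implementation of the target_modules merging step.
from functools import reduce
import operator

def merge_target_modules(target_modules_per_adapter):
    types = {type(tm) for tm in target_modules_per_adapter}
    if len(types) > 1:
        raise ValueError("Mixing str and list/set target_modules is not supported")
    if types == {str}:
        # string configs are regex patterns, so combine them with alternation
        return "|".join(f"({tm})" for tm in target_modules_per_adapter)
    # list/set configs are combined with a set union
    return reduce(operator.or_, (set(tm) for tm in target_modules_per_adapter))

print(merge_target_modules(["q_proj", "v_proj"]))      # (q_proj)|(v_proj)
print(merge_target_modules([{"q_proj"}, {"k_proj"}]))  # {'q_proj', 'k_proj'}
```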
8 changes: 7 additions & 1 deletion src/peft/utils/other.py
@@ -331,7 +331,13 @@ def check_adapter_name(adapter_name):
if isinstance(module, ModulesToSaveWrapper):
# only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care
adapter_name = check_adapter_name(adapter_name)
module.set_adapter(adapter_name)

# if the adapter is found in this module, set it as the active adapter, else disable the adapters of this
# module
if adapter_name in module.modules_to_save:
module.set_adapter(adapter_name)
else:
module.enable_adapters(False)


def _prepare_prompt_learning_config(peft_config, model_config):
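The change in `src/peft/utils/other.py` makes the adapter-switching helper activate a `ModulesToSaveWrapper` only when the requested adapter actually has a saved copy in that wrapper; otherwise the wrapper's extra adapters are disabled and the base layer is used. A toy sketch of that dispatch, using a stand-in class rather than the real `ModulesToSaveWrapper`:

```python
# Toy stand-ins to illustrate the new behaviour; not the real PEFT classes.
class ToyWrapper:
    def __init__(self, modules_to_save):
        self.modules_to_save = modules_to_save  # adapter name -> saved module copy
        self.active_adapter = None
        self.adapters_enabled = True

    def set_adapter(self, adapter_name):
        self.active_adapter = adapter_name
        self.adapters_enabled = True

    def enable_adapters(self, enabled):
        self.adapters_enabled = enabled

def activate(wrapper, adapter_name):
    # mirrors the patched branch: only activate if this wrapper holds a copy
    # for the requested adapter, otherwise fall back to the base layer
    if adapter_name in wrapper.modules_to_save:
        wrapper.set_adapter(adapter_name)
    else:
        wrapper.enable_adapters(False)

wrapper = ToyWrapper({"adapter_a": object()})
activate(wrapper, "adapter_b")
print(wrapper.adapters_enabled)  # False: "adapter_b" has no copy in this wrapper
```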