Fix documentation and consistency issues for AdapterFusion methods #259

Merged
merged 7 commits into from
Dec 21, 2021
1 change: 1 addition & 0 deletions adapter_docs/index.rst
@@ -66,6 +66,7 @@ Currently, we support the PyTorch versions of all models listed in the *Supporte
classes/models/gpt2
classes/models/mbart
classes/models/roberta
classes/models/t5
classes/models/xlmroberta


123 changes: 58 additions & 65 deletions src/transformers/adapters/model_mixin.py
@@ -239,7 +239,7 @@ def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=Non

def add_adapter_fusion(
self,
adapter_names: Union[Fuse, list],
adapter_names: Union[Fuse, list, str],
config=None,
overwrite_ok: bool = False,
set_active: bool = False,
@@ -248,18 +248,24 @@ def add_adapter_fusion(
Adds AdapterFusion to the model with all the necessary configurations and weight initializations

Args:
adapter_names: a list of adapter names which should be fused
adapter_fusion_config (str or dict): adapter fusion configuration, can be either:
adapter_names (Fuse or list or str): AdapterFusion layer to add. Can be either:

- a ``Fuse`` composition block
- a list of adapter names to fuse
- a comma-separated string of adapter names to fuse
config (str or dict): adapter fusion configuration, can be either:

- a string identifying a pre-defined adapter fusion configuration
- a dictionary representing the adapter fusion configuration
- the path to a file containing the adapter fusion configuration
override_kwargs: dictionary items for values which should be overwritten in the default AdapterFusion configuration
overwrite_ok (bool, optional): Overwrite an AdapterFusion layer with the same name if it exists. By default (False), an exception is thrown.
set_active (bool, optional): Activate the added AdapterFusion. By default (False), the AdapterFusion is added but not activated.
"""
# TODO-V2 Allow nested items or directly pass Fuse block?
if isinstance(adapter_names, Fuse):
adapter_names = adapter_names.children
elif isinstance(adapter_names, str):
adapter_names = adapter_names.split(",")

if isinstance(config, dict):
config = AdapterFusionConfig.from_dict(config) # ensure config is ok and up-to-date
# In case adapter already exists and we allow overwriting, explicitly delete the existing one first
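With the widened adapter_names argument, the following calls are equivalent. This is a minimal usage sketch: `model` is assumed to be any adapter-transformers model instance, and the adapter names "a" and "b" are placeholders.

    from transformers.adapters.composition import Fuse

    # hypothetical setup: two adapters that will later be fused
    model.add_adapter("a")
    model.add_adapter("b")

    # three equivalent ways to define the same AdapterFusion layer
    model.add_adapter_fusion(Fuse("a", "b"))                 # Fuse composition block
    model.add_adapter_fusion(["a", "b"], overwrite_ok=True)  # list of adapter names
    model.add_adapter_fusion("a,b", overwrite_ok=True)       # comma-separated string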
@@ -288,19 +294,21 @@ def delete_adapter(self, adapter_name: str):
if self.active_adapters == Stack(adapter_name):
self.active_adapters = None

def delete_adapter_fusion(self, adapter_names: Union[Fuse, list]):
def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]):
"""
Deletes the AdapterFusion layer of the specified adapters.

Args:
adapter_names (Union[Fuse, list]): List of adapters for which to delete the AdapterFusion layer.
adapter_names (Union[Fuse, list, str]): AdapterFusion layer to delete.
"""
if isinstance(adapter_names, Fuse):
adapter_fusion_name = ",".join(adapter_names.children)
elif isinstance(adapter_names, list):
adapter_fusion_name = ",".join(adapter_names)
else:
elif isinstance(adapter_names, str):
adapter_fusion_name = adapter_names
else:
raise ValueError("Invalid AdapterFusion definition: {}".format(adapter_names))

if adapter_fusion_name not in self.config.adapters.fusions:
logger.info("No AdapterFusion '%s' found for deletion. Skipping.", adapter_fusion_name)
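Continuing the sketch above, `delete_adapter_fusion` accepts the same three forms, so any one of the following removes the fusion layer again:

    model.delete_adapter_fusion(Fuse("a", "b"))  # Fuse composition block
    model.delete_adapter_fusion(["a", "b"])      # list of adapter names
    model.delete_adapter_fusion("a,b")           # comma-separated string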
@@ -339,28 +347,36 @@ def save_adapter(
def save_adapter_fusion(
self,
save_directory: str,
adapter_names: list,
adapter_names: Union[Fuse, list, str],
meta_dict: dict = None,
custom_weights_loaders: Optional[List[WeightsLoader]] = None,
):
"""
Saves an adapter and its configuration file to a directory so that it can be shared or reloaded using
`load_adapter()`.
Saves an AdapterFusion layer and its configuration file to a directory so that it can be shared or reloaded
using `load_adapter_fusion()`.

Args:
save_directory (str): Path to a directory where the adapter should be saved.
adapter_name (str): Name of the adapter to be saved.
save_directory (str): Path to a directory where the AdapterFusion should be saved.
adapter_names (Union[Fuse, list, str]): AdapterFusion to be saved.

Raises:
ValueError: If the given adapter name is invalid.
ValueError: If the given AdapterFusion name is invalid.
"""
if isinstance(adapter_names, Fuse):
adapter_fusion_name = ",".join(adapter_names.children)
elif isinstance(adapter_names, list):
adapter_fusion_name = ",".join(adapter_names)
elif isinstance(adapter_names, str):
adapter_fusion_name = adapter_names
else:
raise ValueError("Invalid AdapterFusion definition: {}".format(adapter_names))

loader = AdapterFusionLoader(self)
loader.save(save_directory, adapter_names, meta_dict)
loader.save(save_directory, adapter_fusion_name, meta_dict)
# save additional custom weights
if custom_weights_loaders:
for weights_loader in custom_weights_loaders:
weights_loader.save(save_directory, adapter_names)
weights_loader.save(save_directory, adapter_fusion_name)
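A short continuation of the sketch, saving the fusion layer defined earlier (the directory path is illustrative):

    model.save_adapter_fusion("./checkpoints/a_b_fusion", Fuse("a", "b"))  # or ["a", "b"], or "a,b"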

def load_adapter(
self,
@@ -437,24 +453,16 @@ def load_adapter_fusion(
**kwargs
) -> str:
"""
Loads a pre-trained pytorch adapter module from the local file system or a remote location.
Loads a pre-trained AdapterFusion layer from the local file system.

Args:
adapter_fusion_name_or_path (str): can be either:

- the identifier of a pre-trained task adapter fusion module to be loaded from Adapter Hub
- a path to a directory containing adapter weights saved using `model.saved_adapter()`
- a URL pointing to a zip folder containing a saved adapter module
config (dict or str, optional): The requested configuration of the adapter fusion.
If not specified, will be either: - the default adapter config for the requested adapter fusion if
specified - the global default adapter fusion config
model_name (str, optional): The string identifier of the pre-trained model.
load_as (str, optional): Load the adapter using this name. By default, the name with which the adapter was
saved will be used.
adapter_fusion_name_or_path (str): a path to a directory containing AdapterFusion weights saved using `model.save_adapter_fusion()`.
load_as (str, optional): Load the AdapterFusion using this name.
By default, the name with which the AdapterFusion layer was saved will be used.
set_active (bool, optional): Activate the loaded AdapterFusion. By default (False), the AdapterFusion is loaded but not activated.

Returns:
str: The name with which the adapter was added to the model.
str: The name with which the AdapterFusion was added to the model.
"""

loader = AdapterFusionLoader(self)
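Continuing the sketch, the saved layer can be reloaded from that directory, optionally under a new name and activated right away (names and paths are illustrative):

    fusion_name = model.load_adapter_fusion(
        "./checkpoints/a_b_fusion", load_as="my_fusion", set_active=True
    )
    # fusion_name is the name under which the AdapterFusion was registered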
@@ -500,10 +508,11 @@ def save_all_adapter_fusions(
custom_weights_loaders: Optional[List[WeightsLoader]] = None,
):
"""
Saves all adapters of this model together with their configuration to subfolders of the given location.
Saves all AdapterFusion layers of this model together with their configuration to subfolders of the given
location.

Args:
save_directory (str): Path to a directory where the adapters should be saved.
save_directory (str): Path to a directory where the AdapterFusion layers should be saved.
"""
for name in self.config.adapters.fusions:
adapter_fusion_config = self.config.adapters.get_fusion(name)
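For the bulk variant, a one-line sketch (the path is illustrative); each registered fusion layer is written to its own subfolder:

    model.save_all_adapter_fusions("./checkpoints/all_fusions")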
@@ -813,29 +822,36 @@ def save_all_adapters(
def save_adapter_fusion(
self,
save_directory: str,
adapter_names: list,
adapter_names: Union[Fuse, list, str],
meta_dict: dict = None,
custom_weights_loaders: Optional[List[WeightsLoader]] = None,
with_head: Optional[str] = False,
with_head: Union[bool, str] = False,
):
"""
Saves an adapter and its configuration file to a directory so that it can be shared or reloaded using
`load_adapter()`.
Saves an AdapterFusion layer and its configuration file to a directory so that it can be shared or reloaded
using `load_adapter_fusion()`.

Args:
save_directory (str): Path to a directory where the adapter should be saved.
adapter_names (list): Name of the adapter to be saved.
with_head (str): The name of the head that should be saved with the adapter fusion, if this is True the
head with the same name as the adapter fusion is used
save_directory (str): Path to a directory where the AdapterFusion should be saved.
adapter_names (Union[Fuse, list, str]): AdapterFusion to be saved.
with_head (Union[bool, str]): If True, will save a head with the same name as the AdapterFusionLayer. If a string,
this will be used as the name of the head to be saved.

Raises:
ValueError: If the given adapter name is invalid.
ValueError: If the given AdapterFusion name is invalid.
"""

super().save_adapter_fusion(save_directory, adapter_names, meta_dict, custom_weights_loaders)

if with_head:
head_name = with_head if isinstance(with_head, str) else adapter_names
# Make sure to cover the different options for adapter_names
if isinstance(with_head, str):
head_name = with_head
elif isinstance(adapter_names, Fuse):
head_name = adapter_names.name
elif isinstance(adapter_names, list):
head_name = ",".join(adapter_names)
else:
head_name = adapter_names
if head_name not in self.heads:
raise ValueError("No head with name {} found".format(head_name))
loader = PredictionHeadLoader(self)
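A hedged sketch of the head-aware override: here `model` is assumed to be a model with prediction heads, and "my_head" is a hypothetical head already registered on it.

    # bundle a specific prediction head with the saved fusion
    model.save_adapter_fusion("./checkpoints/a_b_fusion", Fuse("a", "b"), with_head="my_head")
    # with_head=True would instead pick the head named after the fusion, e.g. "a,b"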
@@ -850,29 +866,6 @@ def load_adapter_fusion(
with_head: bool = True,
**kwargs
) -> str:
"""
Loads a pre-trained pytorch adapter module from the local file system or a remote location.

Args:
custom_weights_loaders: custom weight loaders that should be used
with_head: whether the head should be saved with the model. This can specify a head name or True if the adapter has
the same name as the fusion
adapter_fusion_name_or_path (str): can be either:

- the identifier of a pre-trained task adapter fusion module to be loaded from Adapter Hub
- a path to a directory containing adapter weights saved using `model.saved_adapter()`
- a URL pointing to a zip folder containing a saved adapter module
config (dict or str, optional): The requested configuration of the adapter fusion.
If not specified, will be either: - the default adapter config for the requested adapter fusion if
specified - the global default adapter fusion config
model_name (str, optional): The string identifier of the pre-trained model.
load_as (str, optional): Load the adapter using this name. By default, the name with which the adapter was
saved will be used.
set_active (bool, optional): Activate the loaded AdapterFusion. By default (False), the AdapterFusion is loaded but not activated.

Returns:
str: The name with which the adapter was added to the model.
"""
if with_head:
if custom_weights_loaders is None:
custom_weights_loaders = []
6 changes: 6 additions & 0 deletions tests/test_adapter_fusion_common.py
@@ -1,4 +1,5 @@
import copy
import os
import tempfile
from dataclasses import asdict

@@ -12,6 +13,7 @@
PfeifferConfig,
)
from transformers.adapters.composition import Fuse
from transformers.adapters.utils import ADAPTERFUSION_WEIGHTS_NAME
from transformers.testing_utils import require_torch, torch_device


@@ -102,6 +104,10 @@ def test_load_adapter_fusion(self):
model1.save_adapter_fusion(temp_dir, ",".join([name1, name2]))
# also tests that set_active works
model2.load_adapter_fusion(temp_dir, set_active=True)
# In another directory, also check that saving via passing a Fuse block works
with tempfile.TemporaryDirectory() as temp_dir:
model1.save_adapter_fusion(temp_dir, Fuse(name1, name2))
self.assertTrue(os.path.exists(os.path.join(temp_dir, ADAPTERFUSION_WEIGHTS_NAME)))

# check if adapter was correctly loaded
self.assertEqual(model1.config.adapters.fusions.keys(), model2.config.adapters.fusions.keys())