diff --git a/src/transformers/adapters/configuration.py b/src/transformers/adapters/configuration.py index c25d89e6d2..3bb5e5e599 100644 --- a/src/transformers/adapters/configuration.py +++ b/src/transformers/adapters/configuration.py @@ -195,8 +195,12 @@ def __init__(self, **kwargs): adapters_list = dict( map(lambda t: (t[0], t[1][1] or t[1][0] if isinstance(t[1], tuple) else t[1]), adapters_list.items()) ) - self.adapters: Mapping[str] = adapters_list + self.adapters: Mapping[str, str] = adapters_list self.config_map = kwargs.pop("config_map", {}) + + self.fusions: Mapping[str, str] = kwargs.pop("fusions", {}) + self.fusion_config_map = kwargs.pop("fusion_config_map", {}) + # TODO-V2 Save this with config? self.active_setup: Optional[AdapterCompositionBlock] = None self.skip_layers = None @@ -212,7 +216,7 @@ def __iter__(self): def __len__(self): return len(self.adapters) - def get(self, adapter_name: str): + def get(self, adapter_name: str) -> Optional[dict]: """ Gets the config dictionary for a given adapter. @@ -248,7 +252,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None): config = DEFAULT_ADAPTER_CONFIG if isinstance(config, str): if config not in ADAPTER_CONFIG_MAP and config not in self.config_map: - raise ValueError(f"Invalid adapter config identifier '{config}''") + raise ValueError(f"Invalid adapter config identifier '{config}'.") config_name = config # if it's a dict, compute it's hash and add a new entry to the config map elif isinstance(config, Mapping): @@ -259,6 +263,55 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None): self.adapters[adapter_name] = config_name logger.info(f"Adding adapter '{adapter_name}'.") + def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: + """ + Gets the config dictionary for a given AdapterFusion. + + Args: + fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse. + + Returns: + Optional[dict]: The AdapterFusion configuration. + """ + if isinstance(fusion_name, list): + fusion_name = ",".join(fusion_name) + if fusion_name in self.fusions: + config_name = self.fusions[fusion_name] + if config_name in self.fusion_config_map: + config = self.fusion_config_map.get(config_name, None) + else: + config = ADAPTERFUSION_CONFIG_MAP.get(config_name, None) + else: + config = None + return config + + def add_fusion(self, fusion_name: Union[str, List[str]], config: Optional[Union[str, dict]] = None): + """ + Adds a new AdapterFusion. + + Args: + fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse. + config (Optional[Union[str, dict]], optional): AdapterFusion config. Defaults to None. 
+ """ + if isinstance(fusion_name, list): + fusion_name = ",".join(fusion_name) + if fusion_name in self.fusions: + raise ValueError(f"An AdapterFusion with the name '{fusion_name}' has already been added.") + if config is None: + config = DEFAULT_ADAPTERFUSION_CONFIG + if isinstance(config, str): + if config not in ADAPTERFUSION_CONFIG_MAP and config not in self.fusion_config_map: + raise ValueError(f"Invalid AdapterFusion config identifier '{config}'.") + config_name = config + # if it's a dict, compute it's hash and add a new entry to the config map + elif isinstance(config, Mapping): + config_name = get_adapter_config_hash(config) + self.fusion_config_map[config_name] = config + else: + raise ValueError("Invalid AdapterFusion config: {}".format(config)) + self.fusions[fusion_name] = config_name + logger.info(f"Adding AdapterFusion '{fusion_name}'.") + def common_config_value(self, adapter_names: list, attribute: str): """ Checks whether all adapters in a list share the same config setting for a given attribute and returns the @@ -285,6 +338,8 @@ def to_dict(self): output_dict = {} output_dict["adapters"] = copy.deepcopy(self.adapters) output_dict["config_map"] = copy.deepcopy(self.config_map) + output_dict["fusions"] = copy.deepcopy(self.fusions) + output_dict["fusion_config_map"] = copy.deepcopy(self.fusion_config_map) return output_dict diff --git a/src/transformers/adapters/layer.py b/src/transformers/adapters/layer.py index 4dc5901bb1..e084fb6287 100644 --- a/src/transformers/adapters/layer.py +++ b/src/transformers/adapters/layer.py @@ -75,7 +75,12 @@ def add_fusion_layer(self, adapter_names: Union[List, str]): """See BertModel.add_fusion_layer""" adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",") if self.config.adapters.common_config_value(adapter_names, self.adapter_config_key): - fusion = BertFusion(self.config) + fusion_config = self.config.adapters.get_fusion(adapter_names) + fusion = BertFusion( + fusion_config, + self.config.hidden_size, + self.config.attention_probs_dropout_prob, + ) fusion.train(self.training) # make sure training mode is consistent self.adapter_fusion_layer[",".join(adapter_names)] = fusion @@ -114,6 +119,7 @@ def get_adapter_preparams( adapter_config, hidden_states, input_tensor, + fusion_config=None, ): """ Retrieves the hidden_states, query (for Fusion), and residual connection according to the set configuratio @@ -131,7 +137,7 @@ def get_adapter_preparams( if adapter_config["residual_before_ln"]: residual = hidden_states - if hasattr(self.config, "adapter_fusion") and self.config.adapter_fusion["query_before_ln"]: + if fusion_config is not None and fusion_config["query_before_ln"]: query = hidden_states if adapter_config["original_ln_before"]: @@ -143,7 +149,7 @@ def get_adapter_preparams( if not adapter_config["residual_before_ln"]: residual = hidden_states - if hasattr(self.config, "adapter_fusion") and not self.config.adapter_fusion["query_before_ln"]: + if fusion_config is not None and not fusion_config["query_before_ln"]: query = hidden_states return hidden_states, query, residual @@ -196,7 +202,10 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, lvl=0 """ # config of _last_ fused adapter is significant adapter_config = self.config.adapters.get(adapter_setup.last()) - hidden_states, query, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor) + fusion_config = self.config.adapters.get_fusion(adapter_setup.name) + hidden_states, query, residual = 
self.get_adapter_preparams( + adapter_config, hidden_states, input_tensor, fusion_config=fusion_config + ) up_list = [] @@ -239,7 +248,7 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, lvl=0 """ # config of _first_ of splitted adapters is significant adapter_config = self.config.adapters.get(adapter_setup.first()) - hidden_states, query, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor) + hidden_states, _, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor) # split hidden representations and residuals at split index split_hidden_states = [ diff --git a/src/transformers/adapters/loading.py b/src/transformers/adapters/loading.py index b15d587299..e827957d56 100644 --- a/src/transformers/adapters/loading.py +++ b/src/transformers/adapters/loading.py @@ -461,7 +461,7 @@ def rename_func(self, old_name, new_name): "adapter_fusion_layer.{}".format(old_name), "adapter_fusion_layer.{}".format(new_name) ) - def save(self, save_directory: str, name: str): + def save(self, save_directory: str, name: str, meta_dict=None): """ Saves a AdapterFusion module into the given directory. @@ -470,20 +470,19 @@ def save(self, save_directory: str, name: str): name (str, optional): The AdapterFusion name. """ - if hasattr(self.model.config, "adapter_fusion_models"): - if name not in self.model.config.adapter_fusion_models: - if self.error_on_missing: - raise ValueError(f"Unknown AdapterFusion '{name}'.") - else: - logger.debug(f"No AdapterFusion with name '{name}' available.") - return + if name not in self.model.config.adapters.fusions: + if self.error_on_missing: + raise ValueError(f"Unknown AdapterFusion '{name}'.") + else: + logger.debug(f"No AdapterFusion with name '{name}' available.") + return if not exists(save_directory): mkdir(save_directory) else: assert isdir(save_directory), "Saving path should be a directory where the head can be saved." 
- adapter_fusion_config = self.model.config.adapter_fusion + adapter_fusion_config = self.model.config.adapters.get_fusion(name) # Save the adapter fusion configuration config_dict = build_full_config( @@ -493,7 +492,7 @@ def save(self, save_directory: str, name: str): model_name=self.model.model_name, model_class=self.model.__class__.__name__, ) - self.weights_helper.save_weights_config(save_directory, config_dict) + self.weights_helper.save_weights_config(save_directory, config_dict, meta_dict=meta_dict) # Save head weights filter_func = self.filter_func(name) @@ -519,13 +518,13 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs): return None, None config = self.weights_helper.load_weights_config(save_directory) - if not hasattr(self.model.config, "adapter_fusion_models"): - self.model.config.adapter_fusion_models = [] adapter_fusion_name = load_as or config["name"] - if adapter_fusion_name in self.model.config.adapter_fusion_models: + if adapter_fusion_name in self.model.config.adapters.fusions: logger.warning("Overwriting existing adapter fusion module '{}'".format(adapter_fusion_name)) - self.model.add_adapter_fusion(adapter_fusion_name, config["config"], set_active=kwargs.pop("set_active", True)) + self.model.add_adapter_fusion( + adapter_fusion_name, config["config"], overwrite_ok=True, set_active=kwargs.pop("set_active", True) + ) # Load AdapterFusion weights filter_func = self.filter_func(adapter_fusion_name) diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index a790a67190..6179b901ae 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -2,19 +2,12 @@ import warnings from abc import ABC, abstractmethod from os.path import join -from typing import List, Mapping, Optional, Union +from typing import List, Optional, Union from torch import nn from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition -from .configuration import ( - ADAPTERFUSION_CONFIG_MAP, - DEFAULT_ADAPTERFUSION_CONFIG, - AdapterConfig, - AdapterFusionConfig, - ModelAdaptersConfig, - get_adapter_config_hash, -) +from .configuration import AdapterConfig, AdapterFusionConfig, ModelAdaptersConfig, get_adapter_config_hash from .hub_mixin import PushAdapterToHubMixin from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock @@ -105,6 +98,11 @@ def __init__(self, *args, **kwargs): self.adapters = ModelAdaptersConfig(**adapter_config_dict) else: self.adapters = ModelAdaptersConfig() + # Convert AdapterFusions from old format for backwards compatibility + fusion_models = kwargs.pop("adapter_fusion_models", []) + fusion_config = kwargs.pop("adapter_fusion", None) + for fusion_adapter_names in fusion_models: + self.adapters.add_fusion(fusion_adapter_names, config=fusion_config) class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): @@ -129,9 +127,8 @@ def _init_adapter_modules(self): for adapter_name in self.config.adapters: self._add_adapter(adapter_name) # Initialize fusion from config - if hasattr(self.config, "adapter_fusion_models"): - for fusion_adapter_names in self.config.adapter_fusion_models: - self._add_fusion_layer(fusion_adapter_names) + for fusion_name in self.config.adapters.fusions: + self._add_fusion_layer(fusion_name) # These methods have to be implemented by every deriving class: @@ -200,26 +197,6 @@ def set_active_adapters( self.config.adapters.active_setup = 
adapter_setup self.config.adapters.skip_layers = skip_layers - def set_adapter_fusion_config(self, adapter_fusion_config, override_kwargs=None): - """ - Sets the adapter fusion configuration. - - Args: - adapter_fusion_config (str or dict): adapter fusion configuration, can be either: - - - a string identifying a pre-defined adapter fusion configuration - - a dictionary representing the adapter fusion configuration - - the path to a file containing the adapter fusion configuration - """ - if override_kwargs is None: - override_kwargs = {} - if isinstance(adapter_fusion_config, str) and adapter_fusion_config in ADAPTERFUSION_CONFIG_MAP: - self.config.adapter_fusion = AdapterFusionConfig.load(adapter_fusion_config, **override_kwargs) - elif isinstance(adapter_fusion_config, Mapping): - self.config.adapter_fusion = adapter_fusion_config - else: - raise ValueError("Invalid adapter type {}".format(adapter_fusion_config)) - def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False, set_active: bool = False): """ Adds a new adapter module of the specified type to the model. @@ -249,13 +226,14 @@ def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=Non "add_fusion() has been deprecated in favor of add_adapter_fusion(). Please use the newer method instead.", FutureWarning, ) - self.add_adapter_fusion(adapter_names, adapter_fusion_config, override_kwargs) + adapter_fusion_config = AdapterFusionConfig.from_dict(adapter_fusion_config).replace(**override_kwargs) + self.add_adapter_fusion(adapter_names, adapter_fusion_config) def add_adapter_fusion( self, adapter_names: Union[Fuse, list], - adapter_fusion_config=None, - override_kwargs=None, + config=None, + overwrite_ok: bool = False, set_active: bool = False, ): """ @@ -274,26 +252,13 @@ def add_adapter_fusion( # TODO-V2 Allow nested items or directly pass Fuse block? if isinstance(adapter_names, Fuse): adapter_names = adapter_names.children - if not hasattr(self.config, "adapter_fusion"): - if override_kwargs is None: - override_kwargs = {} - if adapter_fusion_config is not None: - self.set_adapter_fusion_config(adapter_fusion_config, **override_kwargs) - else: - self.set_adapter_fusion_config(DEFAULT_ADAPTERFUSION_CONFIG) - elif hasattr(self.config, "adapter_fusion") and adapter_fusion_config is not None: - # This behavior may be a bit unintuitive as the given argument is ignored, but we can't throw an error because of the loader. 
- logger.warning("An AdapterFusion config has already been set and will NOT be overwritten") - - if not hasattr(self.config, "adapter_fusion_models"): - self.config.adapter_fusion_models = [] - if isinstance(adapter_names, list): - adapter_fusion_name = ",".join(adapter_names) - else: - adapter_fusion_name = adapter_names - if adapter_fusion_name not in self.config.adapter_fusion_models: - self.config.adapter_fusion_models.append(adapter_fusion_name) - self.base_model._add_fusion_layer(adapter_names) + if isinstance(config, dict): + config = AdapterFusionConfig.from_dict(config) # ensure config is ok and up-to-date + # In case adapter already exists and we allow overwriting, explicitly delete the existing one first + if overwrite_ok and self.config.adapters.get_fusion(adapter_names) is not None: + self.delete_adapter_fusion(adapter_names) + self.config.adapters.add_fusion(adapter_names, config=config) + self.base_model._add_fusion_layer(adapter_names) if set_active: if not isinstance(adapter_names, list): adapter_names = adapter_names.split(",") @@ -329,14 +294,14 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list]): else: adapter_fusion_name = adapter_names - if ( - not hasattr(self.config, "adapter_fusion_models") - or adapter_fusion_name not in self.config.adapter_fusion_models - ): + if adapter_fusion_name not in self.config.adapters.fusions: logger.info("No AdapterFusion '%s' found for deletion. Skipping.", adapter_fusion_name) return - self.config.adapter_fusion_models.remove(adapter_fusion_name) + del self.config.adapters.fusions[adapter_fusion_name] self.base_model._delete_fusion_layer(adapter_fusion_name) + # Reset active adapters if this was the active setup + if self.active_adapters == adapter_names: + self.active_adapters = None def save_adapter( self, @@ -367,6 +332,7 @@ def save_adapter_fusion( self, save_directory: str, adapter_names: list, + meta_dict: dict = None, custom_weights_loaders: Optional[List[WeightsLoader]] = None, ): """ @@ -382,7 +348,7 @@ def save_adapter_fusion( """ loader = AdapterFusionLoader(self) - loader.save(save_directory, adapter_names) + loader.save(save_directory, adapter_names, meta_dict) # save additional custom weights if custom_weights_loaders: for weights_loader in custom_weights_loaders: @@ -531,17 +497,17 @@ def save_all_adapter_fusions( Args: save_directory (str): Path to a directory where the adapters should be saved. 
""" - if not hasattr(self.config, "adapter_fusion_models"): - return - for name in self.config.adapter_fusion_models: - adapter_fusion_config = self.config.adapter_fusion + for name in self.config.adapters.fusions: + adapter_fusion_config = self.config.adapters.get_fusion(name) h = get_adapter_config_hash(adapter_fusion_config) save_path = join(save_directory, name) if meta_dict: meta_dict.update({"config_id": h}) else: meta_dict = {"config_id": h} - self.save_adapter_fusion(save_path, name, custom_weights_loaders=custom_weights_loaders) + self.save_adapter_fusion( + save_path, name, meta_dict=meta_dict, custom_weights_loaders=custom_weights_loaders + ) def freeze_model(self, freeze=True): """Freezes all weights of the model.""" diff --git a/src/transformers/adapters/modeling.py b/src/transformers/adapters/modeling.py index 2cb9ac3a71..3044b7ca1c 100644 --- a/src/transformers/adapters/modeling.py +++ b/src/transformers/adapters/modeling.py @@ -3,6 +3,8 @@ import torch from torch import nn +from .configuration import AdapterFusionConfig + class Activation_Function_Class(nn.Module): """ @@ -146,42 +148,40 @@ class BertFusion(nn.Module): Implementation of an AdapterFusion block. """ - def __init__(self, config): + def __init__( + self, + config: AdapterFusionConfig, + dense_size, + attention_probs_dropout_prob, + ): super(BertFusion, self).__init__() # if config.hidden_size % config.num_attention_heads != 0: # raise ValueError( # "The hidden size (%d) is not a multiple of the number of attention " # "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.config = config - self.output_attentions = config.output_attentions - self.dense_size = int(config.hidden_size) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.dense_size = dense_size + self.dropout = nn.Dropout(attention_probs_dropout_prob) - if ( - not self.config.adapter_fusion["query"] - and not self.config.adapter_fusion["key"] - and not self.config.adapter_fusion["value"] - ): + if not self.config["query"] and not self.config["key"] and not self.config["value"]: self.dense = nn.Linear(self.dense_size, 1) - if self.config.adapter_fusion["query"]: - self.query = nn.Linear(int(config.hidden_size), self.dense_size) + if self.config["query"]: + self.query = nn.Linear(self.dense_size, self.dense_size) self.query.apply(Adapter.init_bert_weights) - if self.config.adapter_fusion["key"]: + if self.config["key"]: self.key = nn.Linear(self.dense_size, self.dense_size) self.key.apply(Adapter.init_bert_weights) - if self.config.adapter_fusion["value"]: - self.value = nn.Linear(int(config.hidden_size), int(config.hidden_size), bias=False) + if self.config["value"]: + self.value = nn.Linear(self.dense_size, self.dense_size, bias=False) self.value.apply(Adapter.init_bert_weights) - if self.config.adapter_fusion["value_initialized"]: - self.value.weight.data = ( - torch.zeros(int(config.hidden_size), int(config.hidden_size)) + 0.000001 - ).fill_diagonal_(1.0) + if self.config["value_initialized"]: + self.value.weight.data = (torch.zeros(self.dense_size, self.dense_size) + 0.000001).fill_diagonal_(1.0) - if self.config.adapter_fusion["temperature"]: + if self.config["temperature"]: self.T = 50.0 else: self.T = 1.0 @@ -189,20 +189,20 @@ def __init__(self, config): def forward(self, query, key, value, residual): - if self.config.adapter_fusion["residual_before"]: + if self.config["residual_before"]: value += residual[:, :, None, :].repeat(1, 1, value.size(2), 1) - if self.config.adapter_fusion["query"]: + if 
self.config["query"]: query_layer = self.query(query) else: query_layer = query - if self.config.adapter_fusion["key"]: + if self.config["key"]: key_layer = self.key(key) else: key_layer = key - if self.config.adapter_fusion["value"] and self.config.adapter_fusion["value_before_softmax"]: + if self.config["value"] and self.config["value_before_softmax"]: # key/value have dims => batch, toks, number-of-adapters, feats value_layer = self.value(value) else: @@ -222,13 +222,13 @@ def forward(self, query, key, value, residual): context_layer = torch.squeeze(torch.matmul(attention_probs.unsqueeze(2), value_layer), dim=2) - if self.config.adapter_fusion["value"] and not self.config.adapter_fusion["value_before_softmax"]: + if self.config["value"] and not self.config["value_before_softmax"]: # key/value have dims => batch, toks, number-of-adapters, feats context_layer = self.value(context_layer) else: context_layer = context_layer - if not self.config.adapter_fusion["residual_before"]: + if not self.config["residual_before"]: context_layer += residual return context_layer diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8f67ae3b87..12a362267d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -55,6 +55,7 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler from . import __version__ +from .adapters.composition import AdapterCompositionBlock, Fuse from .configuration_utils import PretrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .debug_utils import DebugOption, DebugUnderflowOverflow @@ -407,10 +408,18 @@ def __init__( else: model_freezed = False if model_freezed and self.model.active_adapters: + # Check if training AdapterFusion + self.train_adapter_fusion = ( + isinstance(self.model.active_adapters, Fuse) + or isinstance(self.model.active_adapters, AdapterCompositionBlock) + and any([isinstance(child, Fuse) for child in self.model.active_adapters.children]) + ) + # Configure model saving self.do_save_full_model = False self.do_save_adapters = True - self.do_save_adapter_fusion = True + self.do_save_adapter_fusion = self.train_adapter_fusion else: + self.train_adapter_fusion = False self.do_save_full_model = True self.do_save_adapters = False self.do_save_adapter_fusion = False @@ -806,9 +815,9 @@ def create_optimizer(self): if self.optimizer is None: decay_parameters = get_parameter_names(self.model, [nn.LayerNorm]) decay_parameters = [name for name in decay_parameters if "bias" not in name] - if hasattr(self.model, "config") and hasattr(self.model.config, "adapter_fusion_models"): - no_decay = [f"adapter_fusion_layer.{n}.value" for n in self.model.config.adapter_fusion_models] - decay_parameters = [name for name in decay_parameters if name not in no_decay] + if hasattr(self.model, "config") and hasattr(self.model.config, "adapters"): + match_str = r"adapter_fusion_layer\..*\.value" + decay_parameters = [name for name in decay_parameters if not re.match(match_str, name)] optimizer_grouped_parameters = [ { "params": [p for n, p in self.model.named_parameters() if n in decay_parameters], @@ -1325,10 +1334,7 @@ def train( and (step + 1) == steps_in_epoch ): # apply adapter fusion weight regularization on the value matrix - if ( - hasattr(self.model.config, "adapter_fusion") - and self.model.config.adapter_fusion["regularization"] - ): + if self.train_adapter_fusion: fusion_reg_loss = self.model.base_model.get_fusion_regularization_loss() 
fusion_reg_loss.backward() @@ -1438,8 +1444,7 @@ def train( f"Loading best adapter fusion(s) from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." ) # attempt to re-load all adapter fusions from checkpoint - fusion_models = getattr(self.model.config, "adapter_fusion_models", []) - for fusion in fusion_models: + for fusion in self.model.config.adapters.fusions: fusion_dir = os.path.join(self.state.best_model_checkpoint, fusion) if os.path.exists(fusion_dir): self.model.load_adapter_fusion(fusion_dir) diff --git a/tests/test_adapter_fusion_common.py b/tests/test_adapter_fusion_common.py index a8cbfebb12..592c40eaec 100644 --- a/tests/test_adapter_fusion_common.py +++ b/tests/test_adapter_fusion_common.py @@ -53,7 +53,7 @@ def test_add_adapter_fusion_different_config(self): # correct fusion model.add_adapter_fusion(["a", "b"]) - self.assertIn("a,b", model.config.adapter_fusion_models) + self.assertIn("a,b", model.config.adapters.fusions) # failing fusion self.assertRaises(ValueError, lambda: model.add_adapter_fusion(["a", "c"])) @@ -69,10 +69,10 @@ def test_delete_adapter_fusion(self): self.assertTrue(name2 in model.config.adapters) model.add_adapter_fusion([name1, name2]) - self.assertTrue(",".join([name1, name2]) in model.config.adapter_fusion_models) + self.assertTrue(",".join([name1, name2]) in model.config.adapters.fusions) model.delete_adapter_fusion([name1, name2]) - self.assertFalse(",".join([name1, name2]) in model.config.adapter_fusion_models) + self.assertFalse(",".join([name1, name2]) in model.config.adapters.fusions) def test_load_adapter_fusion(self): for adater_fusion_config_name, adapter_fusion_config in ADAPTERFUSION_CONFIG_MAP.items(): @@ -97,7 +97,7 @@ def test_load_adapter_fusion(self): model2.load_adapter_fusion(temp_dir, set_active=True) # check if adapter was correctly loaded - self.assertTrue(model1.config.adapter_fusion_models == model2.config.adapter_fusion_models) + self.assertEqual(model1.config.adapters.fusions.keys(), model2.config.adapters.fusions.keys()) # check equal output in_data = self.get_input_samples((1, 128), config=model1.config) @@ -121,7 +121,7 @@ def test_load_full_model_fusion(self): model2 = AutoModel.from_pretrained(temp_dir) # check if AdapterFusion was correctly loaded - self.assertTrue(model1.config.adapter_fusion_models == model2.config.adapter_fusion_models) + self.assertTrue(model1.config.adapters.fusions == model2.config.adapters.fusions) # check equal output in_data = self.get_input_samples((1, 128), config=model1.config) @@ -141,6 +141,6 @@ def test_model_config_serialization_fusion(self): model = AutoModel.from_config(self.config()) model.add_adapter("test1") model.add_adapter("test2") - model.add_adapter_fusion(["test1", "test2"], adapter_fusion_config=v) + model.add_adapter_fusion(["test1", "test2"], config=v) # should not raise an exception model.config.to_json_string() diff --git a/tests/test_adapter_trainer.py b/tests/test_adapter_trainer.py index 7914f45c2b..236ce9e170 100644 --- a/tests/test_adapter_trainer.py +++ b/tests/test_adapter_trainer.py @@ -131,8 +131,10 @@ def test_auto_set_save_adapters(self): intermediate_size=37, ) ) - model.add_adapter("adapter") - model.train_adapter("adapter") + model.add_adapter("adapter1") + model.add_adapter("adapter2") + model.add_adapter_fusion(Fuse("adapter1", "adapter2")) + model.train_adapter_fusion(Fuse("adapter1", "adapter2")) training_args = TrainingArguments( output_dir="./examples", diff --git a/tests/test_adapter_training.py b/tests/test_adapter_training.py 
index 4eef4815e8..3dd6fa367d 100644 --- a/tests/test_adapter_training.py +++ b/tests/test_adapter_training.py @@ -108,6 +108,18 @@ def test_train_adapter_fusion(self): state_dict_pre = copy.deepcopy(model.state_dict()) + # Since our config has a value matrix, make sure it is regularized. + # We do this by patching the fusion regularization function. + regularization_called = False + orig_fusion_regularization_loss = model.base_model.get_fusion_regularization_loss + + def patched_fusion_reg_loss(): + nonlocal regularization_called + regularization_called = True + return orig_fusion_regularization_loss() + + model.base_model.get_fusion_regularization_loss = patched_fusion_reg_loss + # setup dataset data_args = GlueDataTrainingArguments( task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True @@ -130,3 +142,5 @@ def test_train_adapter_fusion(self): self.assertFalse(torch.equal(v1, v2), k1) else: self.assertTrue(torch.equal(v1, v2), k1) + + self.assertTrue(regularization_called)
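
For reference, a minimal sketch of how the refactored API above fits together, pieced together from the tests in this diff: fusion metadata now lives in `config.adapters.fusions` / `fusion_config_map` instead of the old `config.adapter_fusion_models` and `config.adapter_fusion` attributes, and `add_adapter_fusion()` takes `config`, `overwrite_ok` and `set_active`. The checkpoint name `bert-base-uncased`, the adapter names `a`/`b`, and the save directory are illustrative placeholders, not taken from the diff.

    from transformers import AutoModel
    from transformers.adapters.composition import Fuse

    # Placeholder checkpoint and adapter names -- any model with adapter support works.
    model = AutoModel.from_pretrained("bert-base-uncased")
    model.add_adapter("a")
    model.add_adapter("b")

    # Fusion configs are now tracked in config.adapters.fusions; passing no config
    # falls back to DEFAULT_ADAPTERFUSION_CONFIG (see ModelAdaptersConfig.add_fusion).
    model.add_adapter_fusion(["a", "b"], set_active=True)
    assert "a,b" in model.config.adapters.fusions
    fusion_config = model.config.adapters.get_fusion(["a", "b"])

    # Train only the fusion layer, then save and reload it (meta_dict is optional).
    model.train_adapter_fusion(Fuse("a", "b"))
    model.save_adapter_fusion("./fusion_ab", "a,b")
    model.load_adapter_fusion("./fusion_ab", set_active=True)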