diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 90cb68c719..877ce4e9a2 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1384,6 +1384,7 @@
         "StaticAdapterFusionConfig",
     ]
     _import_structure["adapters.heads"] = ["ModelWithFlexibleHeadsAdaptersMixin"]
+    _import_structure["adapters.invertible_adapters_mixin"] = ["InvertibleAdaptersMixin"]
     _import_structure["adapters.layer"] = ["AdapterLayerBaseMixin"]
     _import_structure["adapters.loading"] = [
         "AdapterFusionLoader",
@@ -1393,7 +1394,6 @@
         "WeightsLoaderHelper",
     ]
     _import_structure["adapters.model_mixin"] = [
-        "InvertibleAdaptersMixin",
         "ModelAdaptersMixin",
         "ModelConfigAdaptersMixin",
         "ModelWithHeadsAdaptersMixin",
@@ -3170,6 +3170,7 @@
         StaticAdapterFusionConfig,
     )
     from .adapters.heads import ModelWithFlexibleHeadsAdaptersMixin
+    from .adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
     from .adapters.layer import AdapterLayerBaseMixin
     from .adapters.loading import (
         AdapterFusionLoader,
@@ -3178,12 +3179,7 @@
         WeightsLoader,
         WeightsLoaderHelper,
     )
-    from .adapters.model_mixin import (
-        InvertibleAdaptersMixin,
-        ModelAdaptersMixin,
-        ModelConfigAdaptersMixin,
-        ModelWithHeadsAdaptersMixin,
-    )
+    from .adapters.model_mixin import ModelAdaptersMixin, ModelConfigAdaptersMixin, ModelWithHeadsAdaptersMixin
     from .adapters.trainer import AdapterTrainer, Seq2SeqAdapterTrainer
     from .adapters.training import AdapterArguments, MultiLingAdapterArguments
     from .adapters.utils import (
diff --git a/src/transformers/adapters/invertible_adapters_mixin.py b/src/transformers/adapters/invertible_adapters_mixin.py
new file mode 100644
index 0000000000..8ddb142977
--- /dev/null
+++ b/src/transformers/adapters/invertible_adapters_mixin.py
@@ -0,0 +1,67 @@
+from torch import nn
+
+from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock
+
+
+class InvertibleAdaptersMixin:
+    """Mixin for Transformer models adding invertible adapters."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.invertible_adapters = nn.ModuleDict(dict())
+
+    def add_invertible_adapter(self, adapter_name: str):
+        """
+        Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an
+        invertible adapter config, this method does nothing.
+
+        Args:
+            adapter_name (str): The name of the adapter for which to add an invertible adapter module.
+ """ + if adapter_name in self.invertible_adapters: + raise ValueError(f"Model already contains an adapter module for '{adapter_name}'.") + adapter_config = self.config.adapters.get(adapter_name) + if adapter_config and adapter_config["inv_adapter"]: + if adapter_config["inv_adapter"] == "nice": + inv_adap = NICECouplingBlock( + [[self.config.hidden_size]], + non_linearity=adapter_config["non_linearity"], + reduction_factor=adapter_config["inv_adapter_reduction_factor"], + ) + elif adapter_config["inv_adapter"] == "glow": + inv_adap = GLOWCouplingBlock( + [[self.config.hidden_size]], + non_linearity=adapter_config["non_linearity"], + reduction_factor=adapter_config["inv_adapter_reduction_factor"], + ) + else: + raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.") + self.invertible_adapters[adapter_name] = inv_adap + self.invertible_adapters[adapter_name].apply(Adapter.init_bert_weights) + + def delete_invertible_adapter(self, adapter_name: str): + if adapter_name in self.invertible_adapters: + del self.invertible_adapters[adapter_name] + + def get_invertible_adapter(self): + # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position + if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0: + first_adapter = self.config.adapters.active_setup.first() + if first_adapter in self.invertible_adapters: + return self.invertible_adapters[first_adapter] + return None + + def enable_invertible_adapters(self, adapter_names): + for adapter_name in adapter_names: + if adapter_name in self.invertible_adapters: + for param in self.invertible_adapters[adapter_name].parameters(): + param.requires_grad = True + + def invertible_adapters_forward(self, hidden_states, rev=False): + # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position + if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0: + first_adapter = self.config.adapters.active_setup.first() + if first_adapter in self.invertible_adapters: + hidden_states = self.invertible_adapters[first_adapter](hidden_states, rev=rev) + + return hidden_states diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 4a4726fa4a..b686c2dcb4 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -12,77 +12,12 @@ from .configuration import AdapterConfig, AdapterFusionConfig, ModelAdaptersConfig, get_adapter_config_hash from .hub_mixin import PushAdapterToHubMixin from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader -from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock from .utils import EMBEDDING_FILE, TOKENIZER_PATH, inherit_doc logger = logging.getLogger(__name__) -class InvertibleAdaptersMixin: - """Mixin for Transformer models adding invertible adapters.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.invertible_adapters = nn.ModuleDict(dict()) - - def add_invertible_adapter(self, adapter_name: str): - """ - Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an - invertible adapter config, this method does nothing. - - Args: - adapter_name (str): The name of the adapter for which to add an invertible adapter module. 
- """ - if adapter_name in self.invertible_adapters: - raise ValueError(f"Model already contains an adapter module for '{adapter_name}'.") - adapter_config = self.config.adapters.get(adapter_name) - if adapter_config and adapter_config["inv_adapter"]: - if adapter_config["inv_adapter"] == "nice": - inv_adap = NICECouplingBlock( - [[self.config.hidden_size]], - non_linearity=adapter_config["non_linearity"], - reduction_factor=adapter_config["inv_adapter_reduction_factor"], - ) - elif adapter_config["inv_adapter"] == "glow": - inv_adap = GLOWCouplingBlock( - [[self.config.hidden_size]], - non_linearity=adapter_config["non_linearity"], - reduction_factor=adapter_config["inv_adapter_reduction_factor"], - ) - else: - raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.") - self.invertible_adapters[adapter_name] = inv_adap - self.invertible_adapters[adapter_name].apply(Adapter.init_bert_weights) - - def delete_invertible_adapter(self, adapter_name: str): - if adapter_name in self.invertible_adapters: - del self.invertible_adapters[adapter_name] - - def get_invertible_adapter(self): - # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position - if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0: - first_adapter = self.config.adapters.active_setup.first() - if first_adapter in self.invertible_adapters: - return self.invertible_adapters[first_adapter] - return None - - def enable_invertible_adapters(self, adapter_names): - for adapter_name in adapter_names: - if adapter_name in self.invertible_adapters: - for param in self.invertible_adapters[adapter_name].parameters(): - param.requires_grad = True - - def invertible_adapters_forward(self, hidden_states, rev=False): - # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position - if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0: - first_adapter = self.config.adapters.active_setup.first() - if first_adapter in self.invertible_adapters: - hidden_states = self.invertible_adapters[first_adapter](hidden_states, rev=rev) - - return hidden_states - - class ModelConfigAdaptersMixin(ABC): """ Mixin for model config classes, adding support for adapters. 
diff --git a/src/transformers/adapters/models/bert.py b/src/transformers/adapters/models/bert.py
index 85b539dfef..b8ea4a9628 100644
--- a/src/transformers/adapters/models/bert.py
+++ b/src/transformers/adapters/models/bert.py
@@ -15,8 +15,9 @@
     QuestionAnsweringHead,
     TaggingHead,
 )
+from ..invertible_adapters_mixin import InvertibleAdaptersMixin
 from ..layer import AdapterLayerBaseMixin
-from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin
+from ..model_mixin import ModelAdaptersMixin


 logger = logging.getLogger(__name__)
diff --git a/src/transformers/adapters/models/distilbert.py b/src/transformers/adapters/models/distilbert.py
index d7a30b97d9..a1ecd3e4f6 100644
--- a/src/transformers/adapters/models/distilbert.py
+++ b/src/transformers/adapters/models/distilbert.py
@@ -4,7 +4,8 @@
 from torch import nn

 from ..composition import AdapterCompositionBlock, parse_composition
-from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin
+from ..invertible_adapters_mixin import InvertibleAdaptersMixin
+from ..model_mixin import ModelAdaptersMixin
 from .bert import BertEncoderAdaptersMixin, BertModelHeadsMixin, BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin
diff --git a/src/transformers/adapters/models/gpt2.py b/src/transformers/adapters/models/gpt2.py
index d79b26ae0b..e9402e242d 100644
--- a/src/transformers/adapters/models/gpt2.py
+++ b/src/transformers/adapters/models/gpt2.py
@@ -5,7 +5,8 @@
 from ..composition import AdapterCompositionBlock, parse_composition
 from ..heads import CausalLMHead, ClassificationHead, MultiLabelClassificationHead, TaggingHead
-from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin
+from ..invertible_adapters_mixin import InvertibleAdaptersMixin
+from ..model_mixin import ModelAdaptersMixin
 from .bert import (
     BertEncoderAdaptersMixin,
     BertOutputAdaptersMixin,
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index c988d8ee48..c4f052a60e 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -25,7 +25,8 @@
 from torch.nn import CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
-from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
+from ...adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
+from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
 from ...adapters.models.bart import (
     BartDecoderLayerAdaptersMixin,
     BartEncoderDecoderAdaptersMixin,
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index b98a316f5e..8ceea19cd2 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -24,7 +24,8 @@
 from torch.nn import CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
-from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
+from ...adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
+from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
 from ...adapters.models.bart import (
     BartDecoderLayerAdaptersMixin,
     BartEncoderDecoderAdaptersMixin,
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index bdf245cde5..f91af1c2e6 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -25,7 +25,8 @@
 from torch.utils.checkpoint import checkpoint

 from ...activations import ACT2FN
-from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
+from ...adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
+from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
 from ...adapters.models.t5 import (
     T5BlockAdaptersMixin,
     T5CrossAttentionLayerAdaptersMixin,
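End-to-end behavior is likewise unchanged across all of the rewired models. As a usage note (an illustration, not part of the patch): invertible_adapters_forward applies the NICE/GLOW coupling block of the first active adapter, and the same call with rev=True applies its inverse, which is what allows an embedding-side transformation to be undone before an output layer with tied embeddings. A hedged sketch, assuming `model` mixes in InvertibleAdaptersMixin and its active adapter was added with an inv_adapter config:

    # hidden_states: embedding output of shape (batch, seq_len, hidden_size)
    h = model.invertible_adapters_forward(hidden_states)          # forward direction
    h_restored = model.invertible_adapters_forward(h, rev=True)   # inverse direction
    # h_restored matches hidden_states up to numerical precision, because the
    # NICE/GLOW coupling blocks are invertible by construction.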