diff --git a/adapter_docs/classes/bert.rst b/adapter_docs/classes/bert.rst index fc1430b511..073cf5db59 100644 --- a/adapter_docs/classes/bert.rst +++ b/adapter_docs/classes/bert.rst @@ -31,6 +31,13 @@ BertModel :members: +BertModelWithHeads +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BertModelWithHeads + :members: + + BertForPreTraining ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/adapter_docs/classes/distilbert.rst b/adapter_docs/classes/distilbert.rst new file mode 100644 index 0000000000..f618300153 --- /dev/null +++ b/adapter_docs/classes/distilbert.rst @@ -0,0 +1,69 @@ +DistilBERT +=========== + +The DistilBERT model was proposed in the blog post +`Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT <https://medium.com/huggingface/distilbert-8cf3380435b5>`__, +and the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter <https://arxiv.org/abs/1910.01108>`__. +DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer +parameters than `bert-base-uncased`, runs 60% faster, and preserves over 95% of BERT's performance as measured on +the GLUE language understanding benchmark. + +.. note:: + This class is nearly identical to the PyTorch implementation of DistilBERT in Huggingface Transformers. + For more information, visit `the corresponding section in their documentation <https://huggingface.co/transformers/model_doc/distilbert.html>`_. + + +DistilBertConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertConfig + :members: + + +DistilBertTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertTokenizer + :members: + + +DistilBertTokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertTokenizerFast + :members: + + +DistilBertModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertModel + :members: + + +DistilBertModelWithHeads +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertModelWithHeads + :members: + + +DistilBertForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertForMaskedLM + :members: + + +DistilBertForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertForSequenceClassification + :members: + + +DistilBertForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertForQuestionAnswering + :members: diff --git a/adapter_docs/classes/roberta.rst b/adapter_docs/classes/roberta.rst index 13b21c10f3..12a96150c6 100644 --- a/adapter_docs/classes/roberta.rst +++ b/adapter_docs/classes/roberta.rst @@ -31,6 +31,13 @@ RobertaModel :members: +RobertaModelWithHeads +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RobertaModelWithHeads + :members: + + RobertaForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/adapter_docs/classes/xlmroberta.rst b/adapter_docs/classes/xlmroberta.rst index 305cea98be..6ed7f3d1dc 100644 --- a/adapter_docs/classes/xlmroberta.rst +++ b/adapter_docs/classes/xlmroberta.rst @@ -32,6 +32,13 @@ XLMRobertaModel :members: +XLMRobertaModelWithHeads +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +..
autoclass:: transformers.XLMRobertaModelWithHeads + :members: + + XLMRobertaForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/adapter_docs/index.rst b/adapter_docs/index.rst index 3245b1338a..c30b2deeef 100644 --- a/adapter_docs/index.rst +++ b/adapter_docs/index.rst @@ -58,6 +58,7 @@ Currently, we support the PyTorch versions of all models listed in the *Supporte classes/bert classes/roberta classes/xlmroberta + classes/distilbert Citation diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a85cae9c5d..a66c468b73 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -337,6 +337,7 @@ DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel, + DistilBertModelWithHeads, DistilBertForMultipleChoice, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, diff --git a/src/transformers/adapter_bert.py b/src/transformers/adapter_bert.py index a156d16d25..eacfe7c759 100644 --- a/src/transformers/adapter_bert.py +++ b/src/transformers/adapter_bert.py @@ -5,8 +5,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from .adapter_config import DEFAULT_ADAPTER_CONFIG, AdapterType -from .adapter_model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin -from .adapter_modeling import Activation_Function_Class, Adapter, BertFusion, GLOWCouplingBlock, NICECouplingBlock +from .adapter_model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin, ModelWithHeadsAdaptersMixin +from .adapter_modeling import Activation_Function_Class, Adapter, BertFusion from .adapter_utils import flatten_adapter_names, parse_adapter_names @@ -40,6 +40,11 @@ class BertSelfOutputAdaptersMixin: """Adds adapters to the BertSelfOutput module. """ + # override this property if layer norm has a different name + @property + def layer_norm(self): + return self.LayerNorm + def _init_adapter_modules(self): self.attention_text_task_adapters = nn.ModuleDict(dict()) self.adapter_fusion_layer = nn.ModuleDict(dict()) @@ -115,7 +120,7 @@ def get_adapter_preparams( query = hidden_states if adapter_config["original_ln_before"]: - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layer_norm(hidden_states + input_tensor) if not adapter_config["residual_before_ln"]: residual = hidden_states @@ -227,10 +232,10 @@ def adapters_forward(self, hidden_states, input_tensor, adapter_names=None): last_config = self.config.adapters.get(adapter_names[-1][-1]) if last_config["original_ln_after"]: - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layer_norm(hidden_states + input_tensor) else: - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layer_norm(hidden_states + input_tensor) return hidden_states @@ -239,9 +244,12 @@ class BertOutputAdaptersMixin: """Adds adapters to the BertOutput module. 
""" + # override this property if layer norm has a different name + @property + def layer_norm(self): + return self.LayerNorm + def _init_adapter_modules(self): - # self.bert_adapter_att = BertAdapterAttention(config) - # self.bert_adapter_att = SimpleAdapterWeightingSentLvl(config) self.adapter_fusion_layer = nn.ModuleDict(dict()) self.layer_text_task_adapters = nn.ModuleDict(dict()) self.layer_text_lang_adapters = nn.ModuleDict(dict()) @@ -311,7 +319,7 @@ def get_adapter_preparams( query = hidden_states if adapter_config["original_ln_before"]: - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layer_norm(hidden_states + input_tensor) if not adapter_config["residual_before_ln"]: residual = hidden_states @@ -424,10 +432,10 @@ def adapters_forward(self, hidden_states, input_tensor, adapter_names=None): last_config = self.config.adapters.get(adapter_names[-1][-1]) if last_config["original_ln_after"]: - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layer_norm(hidden_states + input_tensor) else: - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layer_norm(hidden_states + input_tensor) return hidden_states @@ -469,7 +477,7 @@ def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze layer.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) -class BertModelAdaptersMixin(ModelAdaptersMixin): +class BertModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): """Adds adapters to the BertModel module. """ @@ -477,7 +485,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _init_adapter_modules(self): - self.invertible_lang_adapters = nn.ModuleDict(dict()) + super()._init_adapter_modules() # language adapters for language in self.config.adapters.adapter_list(AdapterType.text_lang): @@ -496,12 +504,8 @@ def train_adapter(self, adapter_names: list): self.train() self.freeze_model(True) adapter_names_flat = flatten_adapter_names(adapter_names) - self.encoder.enable_adapters(adapter_names_flat, True, False) - # unfreeze invertible adapters for invertible adapters - for adapter_name in adapter_names_flat: - if adapter_name in self.invertible_lang_adapters: - for param in self.invertible_lang_adapters[adapter_name].parameters(): - param.requires_grad = True + self.encoder.enable_adapters(adapter_names, True, False) + self.enable_invertible_adapters(adapter_names_flat) # use the adapters to be trained by default in every forward pass self.set_active_adapters(adapter_names) @@ -535,35 +539,7 @@ def add_adapter(self, adapter_name: str, adapter_type: AdapterType, config=None) if adapter_type == AdapterType.text_lang: self.add_invertible_lang_adapter(adapter_name) - def add_invertible_lang_adapter(self, language): - if language in self.invertible_lang_adapters: - raise ValueError(f"Model already contains an adapter module for '{language}'.") - inv_adap_config = self.config.adapters.get(language)["invertible_adapter"] - if inv_adap_config["block_type"] == "nice": - inv_adap = NICECouplingBlock( - [[self.config.hidden_size]], - non_linearity=inv_adap_config["non_linearity"], - reduction_factor=inv_adap_config["reduction_factor"], - ) - elif inv_adap_config["block_type"] == "glow": - inv_adap = GLOWCouplingBlock( - [[self.config.hidden_size]], - non_linearity=inv_adap_config["non_linearity"], - reduction_factor=inv_adap_config["reduction_factor"], - ) - else: - raise ValueError(f"Invalid invertible adapter type 
'{inv_adap_config['block_type']}'.") - self.invertible_lang_adapters[language] = inv_adap - self.invertible_lang_adapters[language].apply(Adapter.init_bert_weights) - - def get_invertible_lang_adapter(self, language): - if language in self.invertible_lang_adapters: - return self.invertible_lang_adapters[language] - else: - return None - - def add_fusion_layer(self, adapter_names): - """See BertModel.add_attention_layer""" + def _add_fusion_layer(self, adapter_names): self.encoder.add_fusion_layer(adapter_names) diff --git a/src/transformers/adapter_distilbert.py b/src/transformers/adapter_distilbert.py new file mode 100644 index 0000000000..6466e91cad --- /dev/null +++ b/src/transformers/adapter_distilbert.py @@ -0,0 +1,141 @@ +from torch import nn + +from .adapter_bert import ( + BertEncoderAdaptersMixin, + BertModelHeadsMixin, + BertOutputAdaptersMixin, + BertSelfOutputAdaptersMixin, +) +from .adapter_config import DEFAULT_ADAPTER_CONFIG +from .adapter_model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin +from .adapter_utils import AdapterType, flatten_adapter_names + + +class DistilBertSelfAttentionAdaptersModule(nn.Module, BertSelfOutputAdaptersMixin): + """Adds attention adapters to the Transformer module of DistilBert. + """ + + def __init__(self, parent): + super().__init__() + self._layer_norm = parent.sa_layer_norm + self.config = parent.config + + @property + def layer_norm(self): + return self._layer_norm + + +class DistilBertOutputAdaptersModule(nn.Module, BertOutputAdaptersMixin): + """Adds output adapters to the Transformer module of DistilBert. + """ + + def __init__(self, parent): + super().__init__() + self._layer_norm = parent.output_layer_norm + self.config = parent.config + + @property + def layer_norm(self): + return self._layer_norm + + +class DistilBertTransfomerBlockAdaptersMixin: + """Adds adapters to the TransformerBlock module of DistilBert. + """ + + def _init_adapter_modules(self): + self.attention_adapters = DistilBertSelfAttentionAdaptersModule(self) + self.output_adapters = DistilBertOutputAdaptersModule(self) + self.attention_adapters._init_adapter_modules() + self.output_adapters._init_adapter_modules() + + def add_fusion_layer(self, adapter_names): + self.attention_adapters.add_fusion_layer(adapter_names) + self.output_adapters.add_fusion_layer(adapter_names) + + def add_adapter(self, adapter_name: str, adapter_type: AdapterType): + self.attention_adapters.add_adapter(adapter_name, adapter_type) + self.output_adapters.add_adapter(adapter_name, adapter_type) + + def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool): + self.attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) + self.output_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) + + +class DistilBertTransformerAdaptersMixin(BertEncoderAdaptersMixin): + """Adds adapters to the Transformer module of DistilBert. + """ + + pass + + +class DistilBertModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): + """Adds adapters to the DistilBert module. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _init_adapter_modules(self): + super()._init_adapter_modules() + + # language adapters + for language in self.config.adapters.adapter_list(AdapterType.text_lang): + self.transformer.add_adapter(language, AdapterType.text_lang) + self.add_invertible_lang_adapter(language) + # task adapters + for task in self.config.adapters.adapter_list(AdapterType.text_task): + self.transformer.add_adapter(task, AdapterType.text_task) + # fusion + if hasattr(self.config, "fusion_models"): + for fusion_adapter_names in self.config.fusion_models: + self.transformer.add_fusion_layer(fusion_adapter_names) + + def train_adapter(self, adapter_names: list): + """Sets the model in mode for training the given adapters.""" + self.train() + self.freeze_model(True) + adapter_names_flat = flatten_adapter_names(adapter_names) + self.transformer.enable_adapters(adapter_names, True, False) + self.enable_invertible_adapters(adapter_names_flat) + # use the adapters to be trained by default in every forward pass + self.set_active_adapters(adapter_names) + + def train_fusion(self, adapter_names: list): + """Sets the model in mode for training of adapter fusion determined by a list of adapter names.""" + self.train() + self.freeze_model(True) + adapter_names_flat = flatten_adapter_names(adapter_names) + self.transformer.enable_adapters(adapter_names_flat, False, True) + # use the adapters to be trained by default in every forward pass + self.set_active_adapters(adapter_names) + + def add_adapter(self, adapter_name: str, adapter_type: AdapterType, config=None): + """Adds a new adapter module of the specified type to the model. + + Args: + adapter_name (str): The name of the adapter module to be added. + adapter_type (AdapterType): The adapter type. + config (str or dict or AdapterConfig, optional): The adapter configuration, can be either: + - the string identifier of a pre-defined configuration dictionary + - a configuration dictionary specifying the full config + - if not given, the default configuration for this adapter type will be used + """ + if not AdapterType.has(adapter_type): + raise ValueError("Invalid adapter type {}".format(adapter_type)) + if not self.config.adapters.get_config(adapter_type): + self.config.adapters.set_config(adapter_type, config or DEFAULT_ADAPTER_CONFIG) + self.config.adapters.add(adapter_name, adapter_type, config=config) + self.transformer.add_adapter(adapter_name, adapter_type) + if adapter_type == AdapterType.text_lang: + self.add_invertible_lang_adapter(adapter_name) + + def _add_fusion_layer(self, adapter_names): + self.transformer.add_fusion_layer(adapter_names) + + +class DistilBertModelHeadsMixin(BertModelHeadsMixin): + """Adds heads to a DistilBert model. 
+ """ + + pass diff --git a/src/transformers/adapter_model_mixin.py b/src/transformers/adapter_model_mixin.py index 66c13912fb..585d1cee9b 100644 --- a/src/transformers/adapter_model_mixin.py +++ b/src/transformers/adapter_model_mixin.py @@ -6,6 +6,7 @@ from typing import Callable, List, Mapping, Optional, Tuple, Union import torch +from torch import nn from .adapter_config import ( ADAPTERFUSION_CONFIG_MAP, @@ -16,6 +17,7 @@ build_full_config, get_adapter_config_hash, ) +from .adapter_modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock from .adapter_utils import ( ADAPTERFUSION_CONFIG_NAME, ADAPTERFUSION_WEIGHTS_NAME, @@ -624,6 +626,59 @@ def load(self, save_directory, load_as=None, loading_info=None): return save_directory, head_name +class InvertibleAdaptersMixin: + """Mixin for Transformer models adding invertible adapters. + """ + + def _init_adapter_modules(self): + self.invertible_lang_adapters = nn.ModuleDict(dict()) + + def add_invertible_lang_adapter(self, language): + if language in self.invertible_lang_adapters: + raise ValueError(f"Model already contains an adapter module for '{language}'.") + inv_adap_config = self.config.adapters.get(language)["invertible_adapter"] + if inv_adap_config["block_type"] == "nice": + inv_adap = NICECouplingBlock( + [[self.config.hidden_size]], + non_linearity=inv_adap_config["non_linearity"], + reduction_factor=inv_adap_config["reduction_factor"], + ) + elif inv_adap_config["block_type"] == "glow": + inv_adap = GLOWCouplingBlock( + [[self.config.hidden_size]], + non_linearity=inv_adap_config["non_linearity"], + reduction_factor=inv_adap_config["reduction_factor"], + ) + else: + raise ValueError(f"Invalid invertible adapter type '{inv_adap_config['block_type']}'.") + self.invertible_lang_adapters[language] = inv_adap + self.invertible_lang_adapters[language].apply(Adapter.init_bert_weights) + + def get_invertible_lang_adapter(self, adapter_names): + # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position + if adapter_names is not None and len(adapter_names) > 0: + adapter_names = parse_adapter_names(adapter_names) + language = adapter_names[0][0] + if language in self.invertible_lang_adapters: + return self.invertible_lang_adapters[language] + return None + + def enable_invertible_adapters(self, adapter_names): + for adapter_name in adapter_names: + if adapter_name in self.invertible_lang_adapters: + for param in self.invertible_lang_adapters[adapter_name].parameters(): + param.requires_grad = True + + def invertible_adapters_forward(self, hidden_states, adapter_names=None, rev=False): + # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position + if adapter_names is not None and len(adapter_names) > 0: + adapter_names = parse_adapter_names(adapter_names) + if adapter_names[0][0] in self.invertible_lang_adapters: + hidden_states = self.invertible_lang_adapters[adapter_names[0][0]](hidden_states, rev=rev) + + return hidden_states + + class ModelAdaptersMixin(ABC): """Mixin for transformer models adding support for loading/ saving adapters.""" @@ -755,7 +810,7 @@ def add_fusion(self, adapter_names, adapter_fusion_config=None, override_kwargs= adapter_fusion_name = adapter_names if adapter_fusion_name not in self.config.adapter_fusion_models: self.config.adapter_fusion_models.append(adapter_fusion_name) - self.base_model.add_fusion_layer(adapter_names) + self.base_model._add_fusion_layer(adapter_names) def save_adapter( self, diff --git 
a/src/transformers/configuration_distilbert.py b/src/transformers/configuration_distilbert.py index 9bd9baf228..13e624ada4 100644 --- a/src/transformers/configuration_distilbert.py +++ b/src/transformers/configuration_distilbert.py @@ -17,6 +17,7 @@ import logging +from .adapter_config import ModelAdaptersConfig from .configuration_utils import PretrainedConfig @@ -126,6 +127,13 @@ def __init__( self.qa_dropout = qa_dropout self.seq_classif_dropout = seq_classif_dropout + # adapter configuration + adapter_config_dict = kwargs.pop("adapters", None) + if adapter_config_dict: + self.adapters = ModelAdaptersConfig(**adapter_config_dict) + else: + self.adapters = ModelAdaptersConfig() + @property def hidden_size(self): return self.dim @@ -137,3 +145,11 @@ def num_attention_heads(self): @property def num_hidden_layers(self): return self.n_layers + + @property + def hidden_dropout_prob(self): + return self.dropout + + @property + def attention_probs_dropout_prob(self): + return self.attention_dropout diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 875530a55b..cbbe2ca9ca 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -86,6 +86,7 @@ DistilBertForSequenceClassification, DistilBertForTokenClassification, DistilBertModel, + DistilBertModelWithHeads, ) from .modeling_electra import ( ElectraForMaskedLM, @@ -200,6 +201,7 @@ (XLMRobertaConfig, XLMRobertaModelWithHeads), (RobertaConfig, RobertaModelWithHeads), (BertConfig, BertModelWithHeads), + (DistilBertConfig, DistilBertModelWithHeads), ] ) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index f0eec46691..226c89ea11 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -36,7 +36,6 @@ BertSelfOutputAdaptersMixin, ) from .adapter_model_mixin import ModelWithHeadsAdaptersMixin -from .adapter_utils import parse_adapter_names from .configuration_bert import BertConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer @@ -783,13 +782,7 @@ def forward( embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) - - # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position - if adapter_names is not None and len(adapter_names) > 0: - adapter_names = parse_adapter_names(adapter_names) - - if adapter_names[0][0] in self.invertible_lang_adapters: - embedding_output = self.invertible_lang_adapters[adapter_names[0][0]](embedding_output, rev=False) + embedding_output = self.invertible_adapters_forward(embedding_output, adapter_names=adapter_names) encoder_outputs = self.encoder( embedding_output, @@ -965,13 +958,8 @@ def forward( ) sequence_output, pooled_output = outputs[:2] - if adapter_names is not None: - adapter_names = parse_adapter_names(adapter_names) - language = adapter_names[0][0] - else: - language = None prediction_scores, seq_relationship_score = self.cls( - sequence_output, pooled_output, inv_lang_adapter=self.bert.get_invertible_lang_adapter(language), + sequence_output, pooled_output, inv_lang_adapter=self.bert.get_invertible_lang_adapter(adapter_names), ) outputs = (prediction_scores, seq_relationship_score,) + outputs[ @@ -1079,15 +1067,8 @@ def forward( ) sequence_output = outputs[0] - # TODO assume that first elem 
is language - if adapter_names is not None: - adapter_names = parse_adapter_names(adapter_names) - language = adapter_names[0][0] - else: - language = None - prediction_scores = self.cls( - sequence_output, inv_lang_adapter=self.bert.get_invertible_lang_adapter(language), + sequence_output, inv_lang_adapter=self.bert.get_invertible_lang_adapter(adapter_names), ) outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here diff --git a/src/transformers/modeling_distilbert.py b/src/transformers/modeling_distilbert.py index cf93c0b1c7..0d8123ca07 100644 --- a/src/transformers/modeling_distilbert.py +++ b/src/transformers/modeling_distilbert.py @@ -18,7 +18,6 @@ """ -import copy import logging import math import warnings @@ -29,6 +28,13 @@ from torch.nn import CrossEntropyLoss from .activations import gelu +from .adapter_distilbert import ( + DistilBertModelAdaptersMixin, + DistilBertModelHeadsMixin, + DistilBertTransfomerBlockAdaptersMixin, + DistilBertTransformerAdaptersMixin, +) +from .adapter_model_mixin import ModelWithHeadsAdaptersMixin from .configuration_distilbert import DistilBertConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer @@ -209,9 +215,10 @@ def forward(self, input): return x -class TransformerBlock(nn.Module): +class TransformerBlock(DistilBertTransfomerBlockAdaptersMixin, nn.Module): def __init__(self, config): super().__init__() + self.config = config assert config.dim % config.n_heads == 0 @@ -221,7 +228,9 @@ def __init__(self, config): self.ffn = FFN(config) self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) - def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False): + self._init_adapter_modules() + + def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False, adapter_names=None): """ Parameters ---------- @@ -244,11 +253,15 @@ def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False): else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples assert type(sa_output) == tuple sa_output = sa_output[0] - sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) + sa_output = self.attention_adapters.adapters_forward( + sa_output, x, adapter_names=adapter_names + ) # (bs, seq_length, dim) # Feed Forward Network ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) - ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) + ffn_output = self.output_adapters.adapters_forward( + ffn_output, sa_output, adapter_names=adapter_names + ) # (bs, seq_length, dim) output = (ffn_output,) if output_attentions: @@ -256,15 +269,23 @@ def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False): return output -class Transformer(nn.Module): +class Transformer(DistilBertTransformerAdaptersMixin, nn.Module): def __init__(self, config): super().__init__() + self.config = config self.n_layers = config.n_layers - layer = TransformerBlock(config) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)]) + self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) - def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False): + def forward( + self, + x, + attn_mask=None, + head_mask=None, + output_attentions=False, + 
output_hidden_states=False, + adapter_names=None, + ): """ Parameters ---------- @@ -293,7 +314,11 @@ def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False, ou all_hidden_states = all_hidden_states + (hidden_state,) layer_outputs = layer_module( - x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions + x=hidden_state, + attn_mask=attn_mask, + head_mask=head_mask[i], + output_attentions=output_attentions, + adapter_names=adapter_names, ) hidden_state = layer_outputs[-1] @@ -386,13 +411,15 @@ def _init_weights(self, module): "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.", DISTILBERT_START_DOCSTRING, ) -class DistilBertModel(DistilBertPreTrainedModel): +class DistilBertModel(DistilBertModelAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder + self._init_adapter_modules() + self.init_weights() def get_input_embeddings(self): @@ -419,6 +446,7 @@ def forward( inputs_embeds=None, output_attentions=None, output_hidden_states=None, + adapter_names=None, ): r""" Return: @@ -441,6 +469,11 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + # override the default active adapters with those passed in the method call + adapter_names = adapter_names or self.active_adapters + # some warnings if we don't use available adapters + if not adapter_names and self.has_adapters(): + logger.warning("There are adapters available but none are passed to model.forward") if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -461,12 +494,15 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim) + inputs_embeds = self.invertible_adapters_forward(inputs_embeds, adapter_names=adapter_names) + tfmr_output = self.transformer( x=inputs_embeds, attn_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + adapter_names=adapter_names, ) hidden_state = tfmr_output[0] output = (hidden_state,) + tfmr_output[1:] @@ -474,10 +510,47 @@ def forward( return output # last-layer hidden-state, (all hidden_states), (all attentions) +@add_start_docstrings( + """DistilBert Model transformer with the option to add multiple flexible heads on top.""", + DISTILBERT_START_DOCSTRING, +) +class DistilBertModelWithHeads(DistilBertModelHeadsMixin, DistilBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.distilbert = DistilBertModel(config) + + self._init_head_modules() + + self.init_weights() + + @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + labels=None, + adapter_names=None, + head=None, + ): + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + adapter_names=adapter_names, + ) + + outputs = self.forward_head(distilbert_output, head_name=head, attention_mask=attention_mask, labels=labels,) + + return outputs + + @add_start_docstrings( """DistilBert Model with a `masked language modeling` head on top. 
""", DISTILBERT_START_DOCSTRING, ) -class DistilBertForMaskedLM(DistilBertPreTrainedModel): +class DistilBertForMaskedLM(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -504,6 +577,7 @@ def forward( labels=None, output_attentions=None, output_hidden_states=None, + adapter_names=None, **kwargs ): r""" @@ -549,18 +623,21 @@ def forward( inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + adapter_names=adapter_names, ) hidden_states = dlbrt_output[0] # (bs, seq_length, dim) prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim) prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) + prediction_logits = self.distilbert.invertible_adapters_forward( + prediction_logits, adapter_names=adapter_names, rev=True + ) prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size) outputs = (prediction_logits,) + dlbrt_output[1:] if labels is not None: mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1)) outputs = (mlm_loss,) + outputs - return outputs # (mlm_loss), prediction_logits, (all hidden_states), (all attentions) @@ -569,7 +646,7 @@ def forward( the pooled output) e.g. for GLUE tasks. """, DISTILBERT_START_DOCSTRING, ) -class DistilBertForSequenceClassification(DistilBertPreTrainedModel): +class DistilBertForSequenceClassification(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -592,6 +669,7 @@ def forward( labels=None, output_attentions=None, output_hidden_states=None, + adapter_names=None, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): @@ -626,6 +704,7 @@ def forward( inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + adapter_names=adapter_names, ) hidden_state = distilbert_output[0] # (bs, seq_len, dim) pooled_output = hidden_state[:, 0] # (bs, dim) @@ -652,7 +731,7 @@ def forward( the hidden-states output to compute `span start logits` and `span end logits`). """, DISTILBERT_START_DOCSTRING, ) -class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): +class DistilBertForQuestionAnswering(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -675,6 +754,7 @@ def forward( end_positions=None, output_attentions=None, output_hidden_states=None, + adapter_names=None, ): r""" start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): @@ -713,6 +793,7 @@ def forward( inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + adapter_names=adapter_names, ) hidden_states = distilbert_output[0] # (bs, max_query_len, dim) @@ -748,7 +829,7 @@ def forward( the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. 
""", DISTILBERT_START_DOCSTRING, ) -class DistilBertForTokenClassification(DistilBertPreTrainedModel): +class DistilBertForTokenClassification(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -770,6 +851,7 @@ def forward( labels=None, output_attentions=None, output_hidden_states=None, + adapter_names=None, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): @@ -802,6 +884,7 @@ def forward( inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + adapter_names=adapter_names, ) sequence_output = outputs[0] diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 01d97c2db5..dddaff361a 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -25,7 +25,6 @@ from .adapter_bert import BertModelHeadsMixin from .adapter_model_mixin import ModelWithHeadsAdaptersMixin -from .adapter_utils import parse_adapter_names from .configuration_roberta import RobertaConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu @@ -297,15 +296,8 @@ def forward( adapter_names=adapter_names, ) sequence_output = outputs[0] - - # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position - if adapter_names is not None and len(adapter_names) > 0: - adapter_names = parse_adapter_names(adapter_names) - language = adapter_names[0][0] - else: - language = None prediction_scores = self.lm_head( - sequence_output, inv_lang_adapter=self.roberta.get_invertible_lang_adapter(language), + sequence_output, inv_lang_adapter=self.roberta.get_invertible_lang_adapter(adapter_names), ) outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here diff --git a/tests/test_adapter_fusion_loading.py b/tests/test_adapter_fusion_loading.py index 8bee1a1b9b..4b5de5b9c9 100644 --- a/tests/test_adapter_fusion_loading.py +++ b/tests/test_adapter_fusion_loading.py @@ -9,6 +9,7 @@ ADAPTERFUSION_CONFIG_MAP, AdapterType, BertModel, + DistilBertModel, PfeifferConfig, RobertaModel, XLMRobertaModel, @@ -29,7 +30,7 @@ def create_twin_models(model1): @require_torch class AdapterFusionModelTest(unittest.TestCase): - model_classes = [BertModel, RobertaModel, XLMRobertaModel] + model_classes = [BertModel, RobertaModel, XLMRobertaModel, DistilBertModel] def test_add_adapter_fusion(self): diff --git a/tests/test_adapter_loading.py b/tests/test_adapter_loading.py index cdbce1fdbd..0d2698ce0e 100644 --- a/tests/test_adapter_loading.py +++ b/tests/test_adapter_loading.py @@ -9,6 +9,8 @@ AdapterType, BertModel, BertModelWithHeads, + DistilBertModel, + DistilBertModelWithHeads, RobertaModel, RobertaModelWithHeads, XLMRobertaModel, @@ -31,7 +33,7 @@ def create_twin_models(model_class): @require_torch class AdapterModelTest(unittest.TestCase): - model_classes = [BertModel, RobertaModel, XLMRobertaModel] + model_classes = [BertModel, RobertaModel, XLMRobertaModel, DistilBertModel] def test_add_adapter(self): for model_class in self.model_classes: @@ -124,7 +126,7 @@ def test_model_config_serialization(self): @require_torch class PredictionHeadModelTest(unittest.TestCase): - model_classes = [BertModelWithHeads, RobertaModelWithHeads] + 
model_classes = [BertModelWithHeads, RobertaModelWithHeads, DistilBertModelWithHeads] def run_prediction_head_test(self, model, compare_model, head_name, input_shape=(1, 128), output_shape=(1, 2)): with tempfile.TemporaryDirectory() as temp_dir:
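For context, a minimal usage sketch of the DistilBertModelWithHeads class and the adapter-training workflow this diff wires up. The checkpoint, adapter, and head names are illustrative, and add_classification_head() is assumed to come from the inherited BertModelHeadsMixin; treat this as a sketch, not part of the patch.

from transformers import AdapterType, DistilBertModelWithHeads, DistilBertTokenizer

# Illustrative names; add_classification_head() is assumed from BertModelHeadsMixin.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModelWithHeads.from_pretrained("distilbert-base-uncased")

# Register a task adapter and a matching prediction head, then freeze the
# pretrained weights so only the adapter (and head) parameters are trained.
model.add_adapter("sst-2", AdapterType.text_task)
model.add_classification_head("sst-2", num_labels=2)
model.train_adapter(["sst-2"])  # also sets "sst-2" as the active adapter

# adapter_names is threaded through DistilBertModel.forward() by this diff;
# passing it explicitly overrides the active adapters set by train_adapter().
input_ids = tokenizer.encode("AdapterHub is great!", return_tensors="pt")
outputs = model(input_ids, adapter_names=["sst-2"], head="sst-2")
logits = outputs[0]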
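The invertible language-adapter logic factored out into InvertibleAdaptersMixin is shared by BERT and DistilBERT. A short sketch of how it surfaces on the model, assuming the default text_lang adapter configuration contains the "invertible_adapter" block that add_invertible_lang_adapter() reads (checkpoint and adapter name are again illustrative):

from transformers import AdapterType, DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")

# Adding a text_lang adapter also registers a coupling block ("nice" or "glow")
# in model.invertible_lang_adapters via add_invertible_lang_adapter().
model.add_adapter("en", AdapterType.text_lang)

# The block is applied to the embedding output through
# invertible_adapters_forward() and reversed (rev=True) right before the MLM
# vocabulary projection in DistilBertForMaskedLM.
inv_adapter = model.get_invertible_lang_adapter(["en"])
print(type(inv_adapter).__name__)  # e.g. NICECouplingBlock, depending on the config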
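AdapterFusion, which the updated tests now also exercise for DistilBertModel, follows the same pattern. Continuing with the model from the previous sketch (adapter names illustrative; the exact nesting expected by train_fusion() is assumed to match the other model types):

# Two task adapters to fuse; in practice these would be pretrained adapters.
model.add_adapter("sst-2", AdapterType.text_task)
model.add_adapter("mnli", AdapterType.text_task)

# add_fusion() registers the fusion layer through the renamed _add_fusion_layer()
# hook; train_fusion() freezes the single adapters and trains only the fusion weights.
model.add_fusion(["sst-2", "mnli"])
model.train_fusion([["sst-2", "mnli"]])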