diff --git a/docs/classes/models/electra.rst b/docs/classes/models/electra.rst new file mode 100644 index 000000000..d67a96d8d --- /dev/null +++ b/docs/classes/models/electra.rst @@ -0,0 +1,32 @@ +ELECTRA +====== + +The ELECTRA model was proposed in the paper `ELECTRA: Pre-training Text Encoders as Discriminators Rather Than +Generators `__. ELECTRA is a new pretraining approach which trains two +transformer models: the generator and the discriminator. The generator's role is to replace tokens in a sequence, and +is therefore trained as a masked language model. The discriminator, which is the model we're interested in, tries to +identify which tokens were replaced by the generator in the sequence. + +The abstract from the paper is the following: + +*Masked language modeling (MLM) pretraining methods such as BERT corrupt the input by replacing some tokens with [MASK] +and then train a model to reconstruct the original tokens. While they produce good results when transferred to +downstream NLP tasks, they generally require large amounts of compute to be effective. As an alternative, we propose a +more sample-efficient pretraining task called replaced token detection. Instead of masking the input, our approach +corrupts it by replacing some tokens with plausible alternatives sampled from a small generator network. Then, instead +of training a model that predicts the original identities of the corrupted tokens, we train a discriminative model that +predicts whether each token in the corrupted input was replaced by a generator sample or not. Thorough experiments +demonstrate this new pretraining task is more efficient than MLM because the task is defined over all input tokens +rather than just the small subset that was masked out. As a result, the contextual representations learned by our +approach substantially outperform the ones learned by BERT given the same model size, data, and compute. The gains are +particularly strong for small models; for example, we train a model on one GPU for 4 days that outperforms GPT (trained +using 30x more compute) on the GLUE natural language understanding benchmark. Our approach also works well at scale, +where it performs comparably to RoBERTa and XLNet while using less than 1/4 of their compute and outperforms them when +using the same amount of compute.* + +ElectraAdapterModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: adapters.ElectraAdapterModel + :members: + :inherited-members: ElectraPreTrainedModel diff --git a/docs/classes/models/llama.rst b/docs/classes/models/llama.rst index 3ce6f2c84..c7fffe183 100644 --- a/docs/classes/models/llama.rst +++ b/docs/classes/models/llama.rst @@ -1,7 +1,7 @@ LLaMA ----------------------------------------------------------------------------------------------------------------------- -The LLaMA model was proposed in `LLaMA: Open and Efficient Foundation Language Models` by +The LLaMA model was proposed in `LLaMA: Open and Efficient Foundation Language Models `__ by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. It is a collection of foundation language models ranging from 7B to 65B parameters. diff --git a/docs/index.rst b/docs/index.rst index 8ce684cf2..323913cf7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -66,6 +66,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/models/deberta classes/models/deberta_v2 classes/models/distilbert + classes/models/electra classes/models/encoderdecoder classes/models/gpt2 classes/models/gptj diff --git a/docs/model_overview.md b/docs/model_overview.md index 70f71c264..8198ea64d 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -21,6 +21,7 @@ The table below further shows which model architectures support which adaptation | [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | | [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 8cec0ee23..dd6803705 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -99,6 +99,7 @@ "models.deberta": ["DebertaAdapterModel"], "models.deberta_v2": ["DebertaV2AdapterModel"], "models.distilbert": ["DistilBertAdapterModel"], + "models.electra": ["ElectraAdapterModel"], "models.gpt2": ["GPT2AdapterModel"], "models.gptj": ["GPTJAdapterModel"], "models.llama": ["LlamaAdapterModel"], @@ -199,6 +200,7 @@ from .models.deberta import DebertaAdapterModel from .models.deberta_v2 import DebertaV2AdapterModel from .models.distilbert import DistilBertAdapterModel + from .models.electra import ElectraAdapterModel from .models.gpt2 import GPT2AdapterModel from .models.gptj import GPTJAdapterModel from .models.llama import LlamaAdapterModel diff --git a/src/adapters/composition.py b/src/adapters/composition.py index e3ee04925..b4ecab817 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -135,6 +135,7 @@ def __init__( "xlm-roberta", "bert-generation", "llama", + "electra", "xmod", ], } diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index d1425f25d..b4f9ba437 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -598,6 +598,70 @@ }, "layers": ["lm_head"], }, + "ElectraForTokenClassification": { + "config": { + "head_type": "tagging", + "layers": 1, + "activation_function": None, + }, + "layers": [None, "classifier"], + }, + "ElectraForSequenceClassification": { + "config": { + "head_type": "classification", + "layers": 2, + "activation_function": "gelu", + "bias": True, + }, + "layers": [None, "classifier.dense", None, None, "classifier.out_proj"], + }, + "ElectraForQuestionAnswering": { + "config": { + "head_type": "question_answering", + "layers": 1, + "activation_function": None, + }, + "layers": [None, "qa_outputs"], + }, + "ElectraForMultipleChoice": { + "config": { + "head_type": "multiple_choice", + "layers": 2, + "activation_function": "gelu", + "use_pooler": False, + }, + "layers": [None, "sequence_summary.summary", None, None, "classifier"], + }, + "ElectraForMaskedLM": { + "config": { + "head_type": "masked_lm", + "layers": 2, + "activation_function": "gelu", + "layer_norm": True, + "bias": True, + }, + "layers": [ + "generator_predictions.dense", + None, + "generator_predictions.LayerNorm", + "generator_lm_head", + ], + }, + "ElectraForCausalLM": { + "config": { + "head_type": "causal_lm", + "layers": 2, + "activation_function": "gelu", + "layer_norm": True, + "bias": True, + }, + "layers": [ + "generator_predictions.dense", + None, + "generator_predictions.LayerNorm", + "generator_lm_head", + ], + }, } diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 4df8e98f3..b0e59abf1 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -43,6 +43,8 @@ "CLIPModel": CLIPModelAdaptersMixin, "CLIPTextModelWithProjection": CLIPTextModelAdaptersMixin, "CLIPVisionModelWithProjection": CLIPVisionModelAdaptersMixin, + "ElectraLayer": BertLayerAdaptersMixin, + "ElectraModel": BertModelAdaptersMixin, "MBartEncoder": BartEncoderAdaptersMixin, "MBartDecoder": BartDecoderAdaptersMixin, "MBartDecoderWrapper": BartDecoderWrapperAdaptersMixin, diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 75a154c6c..5ff84de48 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -18,6 +18,7 @@ ("deberta", "DebertaAdapterModel"), ("deberta-v2", "DebertaV2AdapterModel"), ("distilbert", "DistilBertAdapterModel"), + ("electra", "ElectraAdapterModel"), ("gpt2", "GPT2AdapterModel"), ("gptj", "GPTJAdapterModel"), ("llama", "LlamaAdapterModel"), diff --git a/src/adapters/models/electra/__init__.py b/src/adapters/models/electra/__init__.py new file mode 100644 index 000000000..bbf0bdbc8 --- /dev/null +++ b/src/adapters/models/electra/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "adapter_model": ["ElectraAdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import ElectraAdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/adapters/models/electra/adapter_model.py b/src/adapters/models/electra/adapter_model.py new file mode 100644 index 000000000..2d7994d3a --- /dev/null +++ b/src/adapters/models/electra/adapter_model.py @@ -0,0 +1,243 @@ +from transformers.models.electra.modeling_electra import ( + ELECTRA_INPUTS_DOCSTRING, + ELECTRA_START_DOCSTRING, + ElectraModel, + ElectraPreTrainedModel, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward + +from ...context import AdapterSetup +from ...heads import ( + BertStyleMaskedLMHead, + BiaffineParsingHead, + CausalLMHead, + ClassificationHead, + ModelWithFlexibleHeadsAdaptersMixin, + MultiLabelClassificationHead, + MultipleChoiceHead, + QuestionAnsweringHead, + TaggingHead, +) +from ...model_mixin import EmbeddingAdaptersWrapperMixin +from ...wrappers import init + + +@add_start_docstrings( + """Electra Model transformer with the option to add multiple flexible heads on top.""", + ELECTRA_START_DOCSTRING, +) +class ElectraAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.electra = ElectraModel(config) + init(self.electra) + + self._init_head_modules() + + self.init_weights() + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs + ): + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + ) + + head_inputs = outputs + + if head or AdapterSetup.get_context_head_setup() or self.active_head: + head_outputs = self.forward_head( + head_inputs, + head_name=head, + attention_mask=attention_mask, + return_dict=return_dict, + **kwargs, + ) + return head_outputs + else: + # in case no head is used just return the output of the base model (including pooler output) + return outputs + + # Copied from BertLMHeadModel + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False), + } + + head_types = { + "classification": ClassificationHead, + "multilabel_classification": MultiLabelClassificationHead, + "tagging": TaggingHead, + "multiple_choice": MultipleChoiceHead, + "question_answering": QuestionAnsweringHead, + "dependency_parsing": BiaffineParsingHead, + "masked_lm": BertStyleMaskedLMHead, + "causal_lm": CausalLMHead, + } + + def add_classification_head( + self, + head_name, + num_labels=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + use_pooler=False, + ): + """ + Adds a sequence classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. + """ + + if multilabel: + head = MultiLabelClassificationHead( + self, head_name, num_labels, layers, activation_function, id2label, use_pooler + ) + else: + head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label, use_pooler) + self.add_prediction_head(head, overwrite_ok) + + def add_multiple_choice_head( + self, + head_name, + num_choices=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + id2label=None, + use_pooler=False, + ): + """ + Adds a multiple choice head on top of the model. + + Args: + head_name (str): The name of the head. + num_choices (int, optional): Number of choices. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = MultipleChoiceHead(self, head_name, num_choices, layers, activation_function, id2label, use_pooler) + self.add_prediction_head(head, overwrite_ok) + + def add_tagging_head( + self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None + ): + """ + Adds a token classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 1. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = TaggingHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_qa_head( + self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None + ): + head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=False, id2label=None): + """ + Adds a biaffine dependency parsing head on top of the model. The parsing head uses the architecture described + in "Is Supervised Syntactic Parsing Beneficial for Language Understanding? An Empirical Investigation" (Glavaš + & Vulić, 2021) (https://arxiv.org/pdf/2008.06788.pdf). + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of labels. Defaults to 2. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + id2label (dict, optional): Mapping from label ids to labels. Defaults to None. + """ + head = BiaffineParsingHead(self, head_name, num_labels, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_masked_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + """ + Adds a masked language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + activation_function (str, optional): Activation function. Defaults to 'gelu'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = BertStyleMaskedLMHead(self, head_name, activation_function=activation_function) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) + + def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + """ + Adds a causal language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + activation_function (str, optional): Activation function. Defaults to 'gelu'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = CausalLMHead( + self, head_name, layers=2, activation_function=activation_function, layer_norm=True, bias=True + ) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py new file mode 100644 index 000000000..0412b4dc1 --- /dev/null +++ b/src/adapters/models/electra/modeling_electra.py @@ -0,0 +1,134 @@ +import math +from typing import Optional, Tuple + +import torch +from torch import nn + +from transformers.models.electra.modeling_electra import ElectraOutput, ElectraSelfAttention, ElectraSelfOutput + +from ...composition import adjust_tensors_for_parallel +from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin + + +class ElectraSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, ElectraSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + key_layer, value_layer, attention_mask = self.prefix_tuning( + key_layer, value_layer, hidden_states, attention_mask + ) + (query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class ElectraSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, ElectraSelfOutput): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + return hidden_states + + +class ElectraOutputWithAdapters(BertOutputAdaptersMixin, ElectraOutput): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + return hidden_states diff --git a/tests_adapters/models/test_electra.py b/tests_adapters/models/test_electra.py new file mode 100644 index 000000000..642eeb0c0 --- /dev/null +++ b/tests_adapters/models/test_electra.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import ElectraAdapterModel +from hf_transformers.tests.models.electra.test_modeling_electra import * +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class ElectraAdapterModelTest(AdapterModelTesterMixin, ElectraModelTester): + all_model_classes = (ElectraAdapterModel,) + fx_compatible = False diff --git a/tests_adapters/test_electra.py b/tests_adapters/test_electra.py new file mode 100644 index 000000000..5566e7be0 --- /dev/null +++ b/tests_adapters/test_electra.py @@ -0,0 +1,64 @@ +import unittest + +from tests_adapters.methods.test_config_union import ConfigUnionAdapterTest +from transformers import ElectraConfig +from transformers.testing_utils import require_torch + +from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, +) +from .test_adapter import AdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_embeddings import EmbeddingTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +class ElectraAdapterTestBase(AdapterTestBase): + config_class = ElectraConfig + config = make_config( + ElectraConfig, + # vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + ) + tokenizer_name = "google/electra-base-generator" + + +@require_torch +class ElectraAdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + EmbeddingTestMixin, + AdapterFusionModelTestMixin, + CompabilityTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, + ConfigUnionAdapterTest, + ElectraAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class ElectraClassConversionTest( + ModelClassConversionTestMixin, + ElectraAdapterTestBase, + unittest.TestCase, +): + pass