From 5b617358d5951e847941c7d09366bf2952b965b5 Mon Sep 17 00:00:00 2001 From: CoderPat Date: Sun, 25 Apr 2021 11:19:23 +0100 Subject: [PATCH 01/11] add electra model to flax --- src/transformers/__init__.py | 24 + src/transformers/models/electra/__init__.py | 26 +- .../models/electra/modeling_flax_electra.py | 673 ++++++++++++++++++ tests/test_modeling_flax_electra.py | 131 ++++ 4 files changed, 853 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/electra/modeling_flax_electra.py create mode 100644 tests/test_modeling_flax_electra.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3e72488be2ab44..3d5be12156ee52 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1390,6 +1390,20 @@ "FlaxBertPreTrainedModel", ] ) + _import_structure["models.electra"].extend( + [ + "FlaxElectraForMaskedLM", + "FlaxElectraForMultipleChoice", + "FlaxElectraForNextSentencePrediction", + "FlaxElectraForPreTraining", + "FlaxElectraForQuestionAnswering", + "FlaxElectraForSequenceClassification", + "FlaxElectraForTokenClassification", + "FlaxElectraModel", + "FlaxElectraPreTrainedModel", + ] + ) + _import_structure["models.roberta"].append("FlaxRobertaModel") else: from .utils import dummy_flax_objects @@ -2551,6 +2565,16 @@ FlaxBertModel, FlaxBertPreTrainedModel, ) + from .models.electra import ( + FlaxElectraForMaskedLM, + FlaxElectraForMultipleChoice, + FlaxElectraForPreTraining, + FlaxElectraForQuestionAnswering, + FlaxElectraForSequenceClassification, + FlaxElectraForTokenClassification, + FlaxElectraModel, + FlaxElectraPreTrainedModel, + ) from .models.roberta import FlaxRobertaModel else: # Import the same objects as dummies to get them in the namespace. diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py index 121bed2f8a6d20..0c8dabeecdd133 100644 --- a/src/transformers/models/electra/__init__.py +++ b/src/transformers/models/electra/__init__.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING -from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available +from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available, is_flax_available _import_structure = { @@ -56,6 +56,18 @@ "TFElectraPreTrainedModel", ] +if is_flax_available(): + _import_structure["modeling_flax_electra"] = [ + "FlaxElectraForMaskedLM", + "FlaxElectraForMultipleChoice", + "FlaxElectraForPreTraining", + "FlaxElectraForQuestionAnswering", + "FlaxElectraForSequenceClassification", + "FlaxElectraForTokenClassification", + "FlaxElectraModel", + "FlaxElectraPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig @@ -91,6 +103,18 @@ TFElectraPreTrainedModel, ) + if is_flax_available(): + from .modeling_flax_electra import ( + FlaxElectraForMaskedLM, + FlaxElectraForMultipleChoice, + FlaxElectraForPreTraining, + FlaxElectraForQuestionAnswering, + FlaxElectraForSequenceClassification, + FlaxElectraForTokenClassification, + FlaxElectraModel, + FlaxElectraPreTrainedModel, + ) + else: import importlib import os diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py new file mode 100644 index 00000000000000..822f8825044c57 --- /dev/null +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -0,0 +1,673 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Tuple + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from flax.linen import dot_product_attention +from jax import lax +from jax.random import PRNGKey + +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring +from ...utils import logging +from .configuration_electra import ElectraConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ElectraConfig" +_TOKENIZER_FOR_DOC = "ElectraTokenizer" + + +ELECTRA_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading, saving and converting weights from + PyTorch models) + + This model is also a Flax Linen `flax.nn.Module + `__ subclass. Use it as a regular Flax + Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ + + Parameters: + config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +ELECTRA_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.ElectraTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + +""" + + +class FlaxElectraEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxElectraSelfAttention(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + + def __call__(self, hidden_states, attention_mask, deterministic=True): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e10).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_output = dot_product_attention( + query_states, + key_states, + value_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + return attn_output.reshape(attn_output.shape[:2] + (-1,)) + + +class FlaxElectraSelfOutput(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FlaxElectraAttention(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, attention_mask, deterministic=True): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic) + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxElectraIntermediate(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class FlaxElectraOutput(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +class FlaxElectraLayer(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxElectraAttention(self.config, dtype=self.dtype) + self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype) + self.output = FlaxElectraOutput(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + return hidden_states + + +class FlaxElectraLayerCollection(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) + return hidden_states + + +class FlaxElectraEncoder(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + return self.layer(hidden_states, attention_mask, deterministic=deterministic) + + +class FlaxElectraPooler(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.gelu(cls_hidden_state) + + +class FlaxElectraGeneratorPredictions(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN[self.config.hidden_act](hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class FlaxElectraDiscriminatorPredictions(nn.Module): + """Prediction module for the discriminator, made up of two dense layers.""" + + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.dense_prediction = nn.Dense(1, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN[self.config.hidden_act](hidden_states) + hidden_states = self.dense_prediction(hidden_states).squeeze(-1) + return hidden_states + + +class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ElectraConfig + base_model_prefix = "electra" + module_class: nn.Module = None + + def __init__( + self, config: ElectraConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + attention_mask = jnp.ones_like(input_ids) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.ones_like(input_ids) + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + +class FlaxElectraModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) + if self.config.embedding_size != self.config.hidden_size: + self.embeddings_project = nn.Dense(self.config.hidden_size) + self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype) + self.pooler = FlaxElectraPooler(self.config, dtype=self.dtype) + + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): + embeddings = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + if hasattr(self, "embeddings_project"): + embeddings = self.embeddings_project(embeddings) + + hidden_states = self.encoder(embeddings, attention_mask, deterministic=deterministic) + + if not self.add_pooling_layer: + return hidden_states + + pooled = self.pooler(hidden_states) + return hidden_states, pooled + +@add_start_docstrings( + "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.", + ELECTRA_START_DOCSTRING, +) +class FlaxElectraModel(FlaxElectraPreTrainedModel): + module_class = FlaxElectraModule + +class FlaxElectraForMaskedLMModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config) + # TODO: should we have the option to include shared embeddings here + self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + hidden_states = self.electra( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + prediction_scores = self.generator_predictions(hidden_states) + prediction_scores = self.generator_lm_head(prediction_scores) + return prediction_scores + +@add_start_docstrings("""Electra Model with a `language modeling` head on top. """, ELECTRA_START_DOCSTRING) +class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForMaskedLMModule + + +class FlaxElectraForPreTrainingModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + hidden_states = self.electra( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + + logits = self.discriminator_predictions(hidden_states) + return logits + +@add_start_docstrings( + """ + Electra model with a binary classification head on top as used during pretraining for identifying generated tokens. + + It is recommended to load the discriminator checkpoint into that model. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForPreTrainingModule + + +class FlaxElectraForTokenClassificationModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.dropout = nn.Dropout(self.config.hidden_dropout_prob) + self.classifier = nn.Dense(self.config.num_labels) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + hidden_states = self.electra( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + return logits + +@add_start_docstrings( + """ + Electra model with a token classification head on top. + + Both the discriminator and generator may be loaded into this model. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForTokenClassificationModule + + +class FlaxElectraForMultipleChoiceModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + _, pooled_output = self.electra( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + return (reshaped_logits,) + + +@add_start_docstrings( + """ + ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForMultipleChoiceModule + + +class FlaxElectraForQuestionAnsweringModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + hidden_states = self.electra(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + return (start_logits, end_logits) + +@add_start_docstrings( + """ + ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForQuestionAnsweringModule + + +class FlaxElectraForSequenceClassificationModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.out_proj = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + ) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + _, pooled_output = self.electra( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.out_proj(pooled_output) + + return (logits,) + +@add_start_docstrings( + """ + Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForSequenceClassificationModule \ No newline at end of file diff --git a/tests/test_modeling_flax_electra.py b/tests/test_modeling_flax_electra.py new file mode 100644 index 00000000000000..666f483c769efa --- /dev/null +++ b/tests/test_modeling_flax_electra.py @@ -0,0 +1,131 @@ +import unittest + +import numpy as np + +from transformers import ElectraConfig, is_flax_available +from transformers.testing_utils import require_flax, slow + +from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask + + +if is_flax_available(): + from transformers.models.electra.modeling_flax_electra import ( + FlaxElectraForMaskedLM, + FlaxElectraForPreTraining, + FlaxElectraForMultipleChoice, + FlaxElectraForQuestionAnswering, + FlaxElectraForTokenClassification, + FlaxElectraForSequenceClassification, + FlaxElectraModel, + ) + + +class FlaxElectraModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_attention_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + embedding_size=24, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_choices=4 + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.embedding_size = embedding_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_choices = num_choices + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + config = ElectraConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + embedding_size=self.embedding_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + ) + + return config, input_ids, token_type_ids, attention_mask + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, token_type_ids, attention_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_flax +class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + FlaxElectraModel, + FlaxElectraForMaskedLM, + FlaxElectraForPreTraining, + FlaxElectraForTokenClassification, + FlaxElectraForQuestionAnswering, + FlaxElectraForMultipleChoice, + FlaxElectraForSequenceClassification + ) if is_flax_available() else () + ) + + def setUp(self): + self.model_tester = FlaxElectraModelTester(self) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + if model_class_name == FlaxElectraForMaskedLM: + model = model_class_name.from_pretrained("google/electra-small-generator", from_pt=True) + else: + model = model_class_name.from_pretrained("google/electra-small-discriminator", from_pt=True) + outputs = model(np.ones((1, 1))) + self.assertIsNotNone(outputs) From 343d61eb3e987dfb5515cafd40ddd95e1b2962b4 Mon Sep 17 00:00:00 2001 From: CoderPat Date: Sun, 25 Apr 2021 11:34:30 +0100 Subject: [PATCH 02/11] Remove Electra Next Sentence Prediction model added by mistake --- src/transformers/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3d5be12156ee52..391bc92a24e2b0 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1394,7 +1394,6 @@ [ "FlaxElectraForMaskedLM", "FlaxElectraForMultipleChoice", - "FlaxElectraForNextSentencePrediction", "FlaxElectraForPreTraining", "FlaxElectraForQuestionAnswering", "FlaxElectraForSequenceClassification", From 1fc055b14cb6dfaec1fa55f88c7c486254e3d195 Mon Sep 17 00:00:00 2001 From: CoderPat Date: Sun, 25 Apr 2021 17:22:01 +0100 Subject: [PATCH 03/11] fix parameter sharing and loosen equality threshold --- src/transformers/modeling_flax_utils.py | 131 +++++++++++++++++- .../models/electra/modeling_flax_electra.py | 118 ++++++++-------- tests/test_modeling_flax_common.py | 6 +- 3 files changed, 197 insertions(+), 58 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 1331b3ba399788..4a4f0782465ac1 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -16,11 +16,14 @@ import os from functools import partial from pickle import UnpicklingError -from typing import Dict, Set, Tuple, Union +from typing import Dict, Set, Tuple, Union, Callable + +import numpy as np import flax.linen as nn import jax import jax.numpy as jnp +from jax import lax from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict @@ -424,6 +427,132 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=F url = self._push_to_hub(save_files=saved_files, **kwargs) logger.info(f"Model pushed to the hub in this commit: {url}") +class SequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are: + + - :obj:`"last"` -- Take the last token hidden state (like XLNet) + - :obj:`"first"` -- Take the first token hidden state (like Bert) + - :obj:`"mean"` -- Take the mean of all tokens hidden states + - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - :obj:`"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to + :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`). + - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the + output, another string or :obj:`None` will add no activation. + - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and + activation. + - **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and + activation. + """ + config: PretrainedConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + identity = lambda x, deterministic=True: x + + self.summary_type = getattr(self.config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = identity + if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj: + if hasattr(self.config, "summary_proj_to_labels") and self.config.summary_proj_to_labels and self.config.num_labels > 0: + num_classes = self.config.num_labels + else: + num_classes = self.config.hidden_size + self.summary = nn.Dense( + num_classes, + dtype=self.dtype + ) + + activation_string = getattr(self.config, "summary_activation", None) + self.activation = ACT2FN[activation_string] if activation_string else lambda x: x + + self.first_dropout = identity + if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(self.config.summary_first_dropout) + + self.last_dropout = identity + if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(self.config.summary_last_dropout) + + def __call__(self, hidden_states, cls_index = None, deterministic: bool = True): + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (:obj:`jnp.array` of shape :obj:`[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (:obj:`jnp.array` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`): + Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification + token. + + Returns: + :obj:`jnp.array`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = jnp.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=jnp.long, + ) + else: + # TODO: + raise NotImplementedError + #cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + #cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output, deterministic=deterministic) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output, deterministic=deterministic) + + return output + + +class TiedDense(nn.Module): + embedding_size: int + dtype: jnp.dtype = jnp.float32 + precision = None + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + bias = self.param("bias", self.bias_init, (self.embedding_size,)) + self.bias = jnp.asarray(bias, dtype=self.dtype) + + def __call__(self, x, kernel): + y = lax.dot_general( + x, + kernel, + (((x.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) + return y + self.bias + def overwrite_call_docstring(model_class, docstring): # copy __call__ function to be sure docstring is changed only for this function diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 822f8825044c57..1d99d45ec7ed30 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -26,7 +26,13 @@ from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + SequenceSummary, + TiedDense, + overwrite_call_docstring +) from ...utils import logging from .configuration_electra import ElectraConfig @@ -326,30 +332,13 @@ def __call__(self, hidden_states, attention_mask, deterministic: bool = True): return self.layer(hidden_states, attention_mask, deterministic=deterministic) -class FlaxElectraPooler(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - cls_hidden_state = hidden_states[:, 0] - cls_hidden_state = self.dense(cls_hidden_state) - return nn.gelu(cls_hidden_state) - - class FlaxElectraGeneratorPredictions(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 def setup(self): - self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype) def __call__(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -443,14 +432,12 @@ def __call__( class FlaxElectraModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True def setup(self): self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) if self.config.embedding_size != self.config.hidden_size: self.embeddings_project = nn.Dense(self.config.hidden_size) self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype) - self.pooler = FlaxElectraPooler(self.config, dtype=self.dtype) def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): embeddings = self.embeddings( @@ -460,12 +447,8 @@ def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, dete embeddings = self.embeddings_project(embeddings) hidden_states = self.encoder(embeddings, attention_mask, deterministic=deterministic) + return (hidden_states,) - if not self.add_pooling_layer: - return hidden_states - - pooled = self.pooler(hidden_states) - return hidden_states, pooled @add_start_docstrings( "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.", @@ -474,25 +457,35 @@ def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, dete class FlaxElectraModel(FlaxElectraPreTrainedModel): module_class = FlaxElectraModule + class FlaxElectraForMaskedLMModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config) - # TODO: should we have the option to include shared embeddings here - self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) + if self.config.tie_word_embeddings: + self.generator_lm_head = TiedDense(self.config.vocab_size, dtype=self.dtype) + else: + self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) def __call__( self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True ): hidden_states = self.electra( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - ) + )[0] prediction_scores = self.generator_predictions(hidden_states) - prediction_scores = self.generator_lm_head(prediction_scores) - return prediction_scores + + if self.config.tie_word_embeddings: + shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T) + else: + prediction_scores = self.generator_lm_head(prediction_scores) + + return (prediction_scores,) + @add_start_docstrings("""Electra Model with a `language modeling` head on top. """, ELECTRA_START_DOCSTRING) class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel): @@ -504,7 +497,7 @@ class FlaxElectraForPreTrainingModule(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype) def __call__( @@ -513,10 +506,11 @@ def __call__( # Model hidden_states = self.electra( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - ) + )[0] logits = self.discriminator_predictions(hidden_states) - return logits + return (logits,) + @add_start_docstrings( """ @@ -535,7 +529,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.dropout = nn.Dropout(self.config.hidden_dropout_prob) self.classifier = nn.Dense(self.config.num_labels) @@ -545,11 +539,12 @@ def __call__( # Model hidden_states = self.electra( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - ) + )[0] hidden_states = self.dropout(hidden_states, deterministic=deterministic) logits = self.classifier(hidden_states) - return logits + return (logits,) + @add_start_docstrings( """ @@ -569,7 +564,7 @@ class FlaxElectraForMultipleChoiceModule(nn.Module): def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.sequence_summary = SequenceSummary(config=self.config, dtype=self.dtype) self.classifier = nn.Dense(1, dtype=self.dtype) def __call__( @@ -582,11 +577,10 @@ def __call__( position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None # Model - _, pooled_output = self.electra( + hidden_states = self.electra( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - ) - - pooled_output = self.dropout(pooled_output, deterministic=deterministic) + )[0] + pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic) logits = self.classifier(pooled_output) reshaped_logits = logits.reshape(-1, num_choices) @@ -610,7 +604,7 @@ class FlaxElectraForQuestionAnsweringModule(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -626,6 +620,7 @@ def __call__( return (start_logits, end_logits) + @add_start_docstrings( """ ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear @@ -637,31 +632,46 @@ class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel): module_class = FlaxElectraForQuestionAnsweringModule +class FlaxElectraClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.hidden_dropout_prob) + self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic: bool = True): + x = hidden_states[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x, deterministic=deterministic) + x = self.dense(x) + x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu here + x = self.dropout(x, deterministic=deterministic) + x = self.out_proj(x) + return x + + class FlaxElectraForSequenceClassificationModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.out_proj = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - ) + self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype) def __call__( self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True ): # Model - _, pooled_output = self.electra( + hidden_states = self.electra( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - ) - - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.out_proj(pooled_output) + )[0] + logits = self.classifier(hidden_states, deterministic=deterministic) return (logits,) + @add_start_docstrings( """ Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 8d5ca111fd9a7a..28de1cc20c12df 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -111,7 +111,7 @@ def test_equivalence_pt_to_flax(self): fx_outputs = fx_model(**prepared_inputs_dict) self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3) with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) @@ -122,7 +122,7 @@ def test_equivalence_pt_to_flax(self): len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" ) for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3) @is_pt_flax_cross_test def test_equivalence_flax_to_pt(self): @@ -152,7 +152,7 @@ def test_equivalence_flax_to_pt(self): fx_outputs = fx_model(**prepared_inputs_dict) self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3) with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) From 53d1835565408fdbe25d0cd5692476279c4e4145 Mon Sep 17 00:00:00 2001 From: CoderPat Date: Sun, 25 Apr 2021 20:21:59 +0100 Subject: [PATCH 04/11] fix styling issues --- src/transformers/modeling_flax_utils.py | 27 +++++--- src/transformers/models/electra/__init__.py | 49 ++++--------- .../models/electra/modeling_flax_electra.py | 38 +++++------ src/transformers/utils/dummy_flax_objects.py | 68 +++++++++++++++++++ tests/test_modeling_flax_electra.py | 20 +++--- 5 files changed, 126 insertions(+), 76 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 4a4f0782465ac1..87c8f9bcd2488b 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -16,17 +16,17 @@ import os from functools import partial from pickle import UnpicklingError -from typing import Dict, Set, Tuple, Union, Callable +from typing import Callable, Dict, Set, Tuple, Union import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from jax import lax from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax from jax.random import PRNGKey from .configuration_utils import PretrainedConfig @@ -58,6 +58,10 @@ } +def identity(x, **kwargs): + return x + + class FlaxPreTrainedModel(PushToHubMixin): r""" Base class for all models. @@ -427,6 +431,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=F url = self._push_to_hub(save_files=saved_files, **kwargs) logger.info(f"Model pushed to the hub in this commit: {url}") + class SequenceSummary(nn.Module): r""" Compute a single vector summary of a sequence hidden states. @@ -458,7 +463,6 @@ class SequenceSummary(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - identity = lambda x, deterministic=True: x self.summary_type = getattr(self.config, "summary_type", "last") if self.summary_type == "attn": @@ -469,14 +473,15 @@ def setup(self): self.summary = identity if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj: - if hasattr(self.config, "summary_proj_to_labels") and self.config.summary_proj_to_labels and self.config.num_labels > 0: + if ( + hasattr(self.config, "summary_proj_to_labels") + and self.config.summary_proj_to_labels + and self.config.num_labels > 0 + ): num_classes = self.config.num_labels else: num_classes = self.config.hidden_size - self.summary = nn.Dense( - num_classes, - dtype=self.dtype - ) + self.summary = nn.Dense(num_classes, dtype=self.dtype) activation_string = getattr(self.config, "summary_activation", None) self.activation = ACT2FN[activation_string] if activation_string else lambda x: x @@ -489,7 +494,7 @@ def setup(self): if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0: self.last_dropout = nn.Dropout(self.config.summary_last_dropout) - def __call__(self, hidden_states, cls_index = None, deterministic: bool = True): + def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): """ Compute a single vector summary of a sequence hidden states. @@ -519,8 +524,8 @@ def __call__(self, hidden_states, cls_index = None, deterministic: bool = True): else: # TODO: raise NotImplementedError - #cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) - #cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + # cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) elif self.summary_type == "attn": diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py index 0c8dabeecdd133..0c51e285079c7f 100644 --- a/src/transformers/models/electra/__init__.py +++ b/src/transformers/models/electra/__init__.py @@ -18,7 +18,13 @@ from typing import TYPE_CHECKING -from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available, is_flax_available +from ...file_utils import ( + _BaseLazyModule, + is_flax_available, + is_tf_available, + is_tokenizers_available, + is_torch_available, +) _import_structure = { @@ -70,50 +76,19 @@ if TYPE_CHECKING: - from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig - from .tokenization_electra import ElectraTokenizer + pass if is_tokenizers_available(): - from .tokenization_electra_fast import ElectraTokenizerFast + pass if is_torch_available(): - from .modeling_electra import ( - ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, - ElectraForMaskedLM, - ElectraForMultipleChoice, - ElectraForPreTraining, - ElectraForQuestionAnswering, - ElectraForSequenceClassification, - ElectraForTokenClassification, - ElectraModel, - ElectraPreTrainedModel, - load_tf_weights_in_electra, - ) + pass if is_tf_available(): - from .modeling_tf_electra import ( - TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, - TFElectraForMaskedLM, - TFElectraForMultipleChoice, - TFElectraForPreTraining, - TFElectraForQuestionAnswering, - TFElectraForSequenceClassification, - TFElectraForTokenClassification, - TFElectraModel, - TFElectraPreTrainedModel, - ) + pass if is_flax_available(): - from .modeling_flax_electra import ( - FlaxElectraForMaskedLM, - FlaxElectraForMultipleChoice, - FlaxElectraForPreTraining, - FlaxElectraForQuestionAnswering, - FlaxElectraForSequenceClassification, - FlaxElectraForTokenClassification, - FlaxElectraModel, - FlaxElectraPreTrainedModel, - ) + pass else: import importlib diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 1d99d45ec7ed30..f731ffdb00354b 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Tuple - -import numpy as np +from typing import Tuple import flax.linen as nn import jax @@ -26,13 +24,7 @@ from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - SequenceSummary, - TiedDense, - overwrite_call_docstring -) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, SequenceSummary, TiedDense from ...utils import logging from .configuration_electra import ElectraConfig @@ -375,7 +367,12 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): module_class: nn.Module = None def __init__( - self, config: ElectraConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + self, + config: ElectraConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) @@ -532,9 +529,9 @@ def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.dropout = nn.Dropout(self.config.hidden_dropout_prob) self.classifier = nn.Dense(self.config.num_labels) - + def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True ): # Model hidden_states = self.electra( @@ -566,9 +563,9 @@ def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.sequence_summary = SequenceSummary(config=self.config, dtype=self.dtype) self.classifier = nn.Dense(1, dtype=self.dtype) - + def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True ): num_choices = input_ids.shape[1] input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None @@ -611,7 +608,9 @@ def __call__( self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True ): # Model - hidden_states = self.electra(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + hidden_states = self.electra( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) logits = self.qa_outputs(hidden_states) start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) @@ -634,6 +633,7 @@ class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel): class FlaxElectraClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" + config: ElectraConfig dtype: jnp.dtype = jnp.float32 @@ -674,10 +674,10 @@ def __call__( @add_start_docstrings( """ - Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. + Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. """, ELECTRA_START_DOCSTRING, ) class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForSequenceClassificationModule \ No newline at end of file + module_class = FlaxElectraForSequenceClassificationModule diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 1b1e61b6298693..c6314db1548e43 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -180,6 +180,74 @@ def from_pretrained(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxElectraForMaskedLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxElectraForMultipleChoice: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxElectraForPreTraining: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxElectraForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxElectraForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxElectraForTokenClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxElectraModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxElectraPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxRobertaModel: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) diff --git a/tests/test_modeling_flax_electra.py b/tests/test_modeling_flax_electra.py index 666f483c769efa..bcc1cb1c2fa9fc 100644 --- a/tests/test_modeling_flax_electra.py +++ b/tests/test_modeling_flax_electra.py @@ -11,11 +11,11 @@ if is_flax_available(): from transformers.models.electra.modeling_flax_electra import ( FlaxElectraForMaskedLM, - FlaxElectraForPreTraining, FlaxElectraForMultipleChoice, + FlaxElectraForPreTraining, FlaxElectraForQuestionAnswering, - FlaxElectraForTokenClassification, FlaxElectraForSequenceClassification, + FlaxElectraForTokenClassification, FlaxElectraModel, ) @@ -43,7 +43,7 @@ def __init__( type_vocab_size=16, type_sequence_label_size=2, initializer_range=0.02, - num_choices=4 + num_choices=4, ): self.parent = parent self.batch_size = batch_size @@ -107,14 +107,16 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase): all_model_classes = ( ( - FlaxElectraModel, - FlaxElectraForMaskedLM, - FlaxElectraForPreTraining, - FlaxElectraForTokenClassification, + FlaxElectraModel, + FlaxElectraForMaskedLM, + FlaxElectraForPreTraining, + FlaxElectraForTokenClassification, FlaxElectraForQuestionAnswering, FlaxElectraForMultipleChoice, - FlaxElectraForSequenceClassification - ) if is_flax_available() else () + FlaxElectraForSequenceClassification, + ) + if is_flax_available() + else () ) def setUp(self): From 8d4621e2d46f72eec34b4efb3e2a6307d882527c Mon Sep 17 00:00:00 2001 From: CoderPat Date: Sun, 25 Apr 2021 20:29:03 +0100 Subject: [PATCH 05/11] add mistaken removen imports --- src/transformers/models/electra/__init__.py | 41 ++++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py index 0c51e285079c7f..729c35ea58516e 100644 --- a/src/transformers/models/electra/__init__.py +++ b/src/transformers/models/electra/__init__.py @@ -76,19 +76,50 @@ if TYPE_CHECKING: - pass + from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig + from .tokenization_electra import ElectraTokenizer if is_tokenizers_available(): - pass + from .tokenization_electra_fast import ElectraTokenizerFast if is_torch_available(): - pass + from .modeling_electra import ( + ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, + ElectraForMaskedLM, + ElectraForMultipleChoice, + ElectraForPreTraining, + ElectraForQuestionAnswering, + ElectraForSequenceClassification, + ElectraForTokenClassification, + ElectraModel, + ElectraPreTrainedModel, + load_tf_weights_in_electra, + ) if is_tf_available(): - pass + from .modeling_tf_electra import ( + TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFElectraForMaskedLM, + TFElectraForMultipleChoice, + TFElectraForPreTraining, + TFElectraForQuestionAnswering, + TFElectraForSequenceClassification, + TFElectraForTokenClassification, + TFElectraModel, + TFElectraPreTrainedModel, + ) if is_flax_available(): - pass + from .modeling_flax_electra import ( + FlaxElectraForMaskedLM, + FlaxElectraForMultipleChoice, + FlaxElectraForPreTraining, + FlaxElectraForQuestionAnswering, + FlaxElectraForSequenceClassification, + FlaxElectraForTokenClassification, + FlaxElectraModel, + FlaxElectraPreTrainedModel, + ) else: import importlib From f536d7774f10b7cec652c18e89f76476eb8710a5 Mon Sep 17 00:00:00 2001 From: CoderPat Date: Sun, 25 Apr 2021 21:18:35 +0100 Subject: [PATCH 06/11] fix electra table --- docs/source/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 25a2a380431e7a..ce12ec7ad37bf7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -292,7 +292,7 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ | +| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ From f41e4db62662913e36a069910a97770da62b159f Mon Sep 17 00:00:00 2001 From: CoderPat Date: Sun, 25 Apr 2021 22:57:35 +0100 Subject: [PATCH 07/11] Add FlaxElectra to automodels and fixe docs --- docs/source/model_doc/electra.rst | 49 +++++++++++++++++++ .../models/auto/modeling_flax_auto.py | 18 ++++++- .../models/electra/modeling_flax_electra.py | 8 ++- 3 files changed, 73 insertions(+), 2 deletions(-) diff --git a/docs/source/model_doc/electra.rst b/docs/source/model_doc/electra.rst index a332b1fd88e65e..cf15ccc7cb4cbf 100644 --- a/docs/source/model_doc/electra.rst +++ b/docs/source/model_doc/electra.rst @@ -185,3 +185,52 @@ TFElectraForQuestionAnswering .. autoclass:: transformers.TFElectraForQuestionAnswering :members: call + + +FlaxElectraModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxElectraModel + :members: __call__ + + +FlaxElectraForPreTraining +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxElectraForPreTraining + :members: __call__ + + +FlaxElectraForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxElectraForMaskedLM + :members: __call__ + + +FlaxElectraForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxElectraForSequenceClassification + :members: __call__ + + +FlaxElectraForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxElectraForMultipleChoice + :members: __call__ + + +FlaxElectraForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxElectraForTokenClassification + :members: __call__ + + +FlaxElectraForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxElectraForQuestionAnswering + :members: __call__ diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 042612d0a52909..b80c7f6f738e09 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -28,9 +28,18 @@ FlaxBertForTokenClassification, FlaxBertModel, ) +from ..electra.modeling_flax_electra import ( + FlaxElectraForMaskedLM, + FlaxElectraForMultipleChoice, + FlaxElectraForPreTraining, + FlaxElectraForQuestionAnswering, + FlaxElectraForSequenceClassification, + FlaxElectraForTokenClassification, + FlaxElectraModel, +) from ..roberta.modeling_flax_roberta import FlaxRobertaModel from .auto_factory import auto_class_factory -from .configuration_auto import BertConfig, RobertaConfig +from .configuration_auto import BertConfig, ElectraConfig, RobertaConfig logger = logging.get_logger(__name__) @@ -41,6 +50,7 @@ # Base model mapping (RobertaConfig, FlaxRobertaModel), (BertConfig, FlaxBertModel), + (ElectraConfig, FlaxElectraModel), ] ) @@ -48,6 +58,7 @@ [ # Model for pre-training mapping (BertConfig, FlaxBertForPreTraining), + (ElectraConfig, FlaxElectraForPreTraining), ] ) @@ -55,6 +66,7 @@ [ # Model for Masked LM mapping (BertConfig, FlaxBertForMaskedLM), + (ElectraConfig, FlaxElectraForMaskedLM), ] ) @@ -62,6 +74,7 @@ [ # Model for Sequence Classification mapping (BertConfig, FlaxBertForSequenceClassification), + (ElectraConfig, FlaxElectraForSequenceClassification), ] ) @@ -69,6 +82,7 @@ [ # Model for Question Answering mapping (BertConfig, FlaxBertForQuestionAnswering), + (ElectraConfig, FlaxElectraForQuestionAnswering), ] ) @@ -76,6 +90,7 @@ [ # Model for Token Classification mapping (BertConfig, FlaxBertForTokenClassification), + (ElectraConfig, FlaxElectraForTokenClassification), ] ) @@ -83,6 +98,7 @@ [ # Model for Multiple Choice mapping (BertConfig, FlaxBertForMultipleChoice), + (ElectraConfig, FlaxElectraForMultipleChoice), ] ) diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index f731ffdb00354b..814473f981a08d 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -24,7 +24,7 @@ from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, SequenceSummary, TiedDense +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, SequenceSummary, TiedDense, overwrite_call_docstring from ...utils import logging from .configuration_electra import ElectraConfig @@ -596,6 +596,12 @@ class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel): module_class = FlaxElectraForMultipleChoiceModule +# adapt docstring slightly for FlaxElectraForMultipleChoice +overwrite_call_docstring( + FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) + + class FlaxElectraForQuestionAnsweringModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 From 9f96419d18267595e7090997c9a1d839b69ca1bb Mon Sep 17 00:00:00 2001 From: CoderPat Date: Mon, 26 Apr 2021 22:15:36 +0100 Subject: [PATCH 08/11] fix issues pointed out the PR --- src/transformers/modeling_flax_utils.py | 25 +--- .../models/electra/modeling_flax_electra.py | 113 +++++++++++++++++- tests/test_modeling_flax_common.py | 2 +- 3 files changed, 110 insertions(+), 30 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 87c8f9bcd2488b..24f2b2d2da42db 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -16,9 +16,7 @@ import os from functools import partial from pickle import UnpicklingError -from typing import Callable, Dict, Set, Tuple, Union - -import numpy as np +from typing import Dict, Set, Tuple, Union import flax.linen as nn import jax @@ -26,7 +24,6 @@ from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax from jax.random import PRNGKey from .configuration_utils import PretrainedConfig @@ -539,26 +536,6 @@ def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): return output -class TiedDense(nn.Module): - embedding_size: int - dtype: jnp.dtype = jnp.float32 - precision = None - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - bias = self.param("bias", self.bias_init, (self.embedding_size,)) - self.bias = jnp.asarray(bias, dtype=self.dtype) - - def __call__(self, x, kernel): - y = lax.dot_general( - x, - kernel, - (((x.ndim - 1,), (0,)), ((), ())), - precision=self.precision, - ) - return y + self.bias - - def overwrite_call_docstring(model_class, docstring): # copy __call__ function to be sure docstring is changed only for this function model_class.__call__ = copy_func(model_class.__call__) diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 814473f981a08d..f9bf2abcf57717 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Callable, Tuple + +import numpy as np import flax.linen as nn import jax @@ -24,7 +26,7 @@ from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, SequenceSummary, TiedDense, overwrite_call_docstring +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring from ...utils import logging from .configuration_electra import ElectraConfig @@ -136,6 +138,7 @@ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, dete return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra class FlaxElectraSelfAttention(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -207,6 +210,7 @@ def __call__(self, hidden_states, attention_mask, deterministic=True): return attn_output.reshape(attn_output.shape[:2] + (-1,)) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra class FlaxElectraSelfOutput(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -227,6 +231,7 @@ def __call__(self, hidden_states, input_tensor, deterministic: bool = True): return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra class FlaxElectraAttention(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 @@ -244,6 +249,7 @@ def __call__(self, hidden_states, attention_mask, deterministic=True): return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra class FlaxElectraIntermediate(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -262,6 +268,7 @@ def __call__(self, hidden_states): return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Electra class FlaxElectraOutput(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -282,6 +289,7 @@ def __call__(self, hidden_states, attention_output, deterministic: bool = True): return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra class FlaxElectraLayer(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -298,6 +306,7 @@ def __call__(self, hidden_states, attention_mask, deterministic: bool = True): return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra class FlaxElectraLayerCollection(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -313,6 +322,7 @@ def __call__(self, hidden_states, attention_mask, deterministic: bool = True): return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra class FlaxElectraEncoder(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -455,6 +465,26 @@ class FlaxElectraModel(FlaxElectraPreTrainedModel): module_class = FlaxElectraModule +class FlaxElectraTiedDense(nn.Module): + embedding_size: int + dtype: jnp.dtype = jnp.float32 + precision = None + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + bias = self.param("bias", self.bias_init, (self.embedding_size,)) + self.bias = jnp.asarray(bias, dtype=self.dtype) + + def __call__(self, x, kernel): + y = lax.dot_general( + x, + kernel, + (((x.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) + return y + self.bias + + class FlaxElectraForMaskedLMModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 @@ -463,7 +493,7 @@ def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config) if self.config.tie_word_embeddings: - self.generator_lm_head = TiedDense(self.config.vocab_size, dtype=self.dtype) + self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) else: self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) @@ -555,13 +585,86 @@ class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel): module_class = FlaxElectraForTokenClassificationModule +def identity(x, **kwargs): + return x + + +class FlaxElectraSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to + :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`). + - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the + output, another string or :obj:`None` will add no activation. + - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and + activation. + - **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and + activation. + """ + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.summary = identity + if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj: + if ( + hasattr(self.config, "summary_proj_to_labels") + and self.config.summary_proj_to_labels + and self.config.num_labels > 0 + ): + num_classes = self.config.num_labels + else: + num_classes = self.config.hidden_size + self.summary = nn.Dense(num_classes, dtype=self.dtype) + + activation_string = getattr(self.config, "summary_activation", None) + self.activation = ACT2FN[activation_string] if activation_string else lambda x: x + + self.first_dropout = identity + if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(self.config.summary_first_dropout) + + self.last_dropout = identity + if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(self.config.summary_last_dropout) + + def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (:obj:`jnp.array` of shape :obj:`[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (:obj:`jnp.array` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`): + Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification + token. + + Returns: + :obj:`jnp.array`: The summary of the sequence hidden states. + """ + # NOTE: this doest "first" type summary always + output = hidden_states[:, 0] + output = self.first_dropout(output, deterministic=deterministic) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output, deterministic=deterministic) + return output + + class FlaxElectraForMultipleChoiceModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) - self.sequence_summary = SequenceSummary(config=self.config, dtype=self.dtype) + self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype) self.classifier = nn.Dense(1, dtype=self.dtype) def __call__( @@ -652,7 +755,7 @@ def __call__(self, hidden_states, deterministic: bool = True): x = hidden_states[:, 0, :] # take token (equiv. to [CLS]) x = self.dropout(x, deterministic=deterministic) x = self.dense(x) - x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu here + x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu x = self.dropout(x, deterministic=deterministic) x = self.out_proj(x) return x diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 28de1cc20c12df..a807cb637bb1ca 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -152,7 +152,7 @@ def test_equivalence_flax_to_pt(self): fx_outputs = fx_model(**prepared_inputs_dict) self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) From 2d19a8c2cf97d2d401181ee8660b7913360ab53e Mon Sep 17 00:00:00 2001 From: CoderPat Date: Tue, 4 May 2021 12:04:31 +0100 Subject: [PATCH 09/11] fix flax electra to comply with latest changes --- .../models/electra/modeling_flax_electra.py | 452 ++++++++++++++++-- tests/test_modeling_flax_electra.py | 4 +- 2 files changed, 405 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index f9bf2abcf57717..66ef9a51871ebc 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -13,30 +13,73 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Tuple +from dataclasses import dataclass +from typing import Callable, Optional, Tuple import numpy as np import flax.linen as nn import jax import jax.numpy as jnp +import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict from flax.linen import dot_product_attention from jax import lax from jax.random import PRNGKey -from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring +from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) from ...utils import logging from .configuration_electra import ElectraConfig logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator" _CONFIG_FOR_DOC = "ElectraConfig" _TOKENIZER_FOR_DOC = "ElectraTokenizer" +@dataclass +class FlaxElectraForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.ElectraForPreTraining`. + + Args: + logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + ELECTRA_START_DOCSTRING = r""" This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the @@ -165,7 +208,7 @@ def setup(self): kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), ) - def __call__(self, hidden_states, attention_mask, deterministic=True): + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): head_dim = self.config.hidden_size // self.config.num_attention_heads query_states = self.query(hidden_states).reshape( @@ -207,7 +250,12 @@ def __call__(self, hidden_states, attention_mask, deterministic=True): precision=None, ) - return attn_output.reshape(attn_output.shape[:2] + (-1,)) + outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),) + + # TODO: at the moment it's not possible to retrieve attn_weights from + # dot_product_attention, but should be in the future -> add functionality then + + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra @@ -240,13 +288,22 @@ def setup(self): self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype) self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic=True): + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic) + attn_outputs = self.self( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attn_output = attn_outputs[0] hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += attn_outputs[1] + + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra @@ -299,11 +356,20 @@ def setup(self): self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype) self.output = FlaxElectraOutput(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) + def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False): + attention_outputs = self.attention( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attention_output = attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra @@ -316,10 +382,40 @@ def setup(self): FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) - return hidden_states + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra @@ -330,8 +426,23 @@ class FlaxElectraEncoder(nn.Module): def setup(self): self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - return self.layer(hidden_states, attention_mask, deterministic=deterministic) + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) class FlaxElectraGeneratorPredictions(nn.Module): @@ -409,7 +520,22 @@ def __call__( params: dict = None, dropout_rng: PRNGKey = None, train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if output_attentions: + raise NotImplementedError( + "Currently attention scores cannot be returned. Please set `output_attentions` to False for now." + ) + # init input tensors if not passed if token_type_ids is None: token_type_ids = jnp.ones_like(input_ids) @@ -432,6 +558,9 @@ def __call__( jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), not train, + output_attentions, + output_hidden_states, + return_dict, rngs=rngs, ) @@ -446,15 +575,31 @@ def setup(self): self.embeddings_project = nn.Dense(self.config.hidden_size) self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype) - def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): embeddings = self.embeddings( input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic ) if hasattr(self, "embeddings_project"): embeddings = self.embeddings_project(embeddings) - hidden_states = self.encoder(embeddings, attention_mask, deterministic=deterministic) - return (hidden_states,) + return self.encoder( + embeddings, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) @add_start_docstrings( @@ -465,6 +610,11 @@ class FlaxElectraModel(FlaxElectraPreTrainedModel): module_class = FlaxElectraModule +append_call_sample_docstring( + FlaxElectraModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC +) + + class FlaxElectraTiedDense(nn.Module): embedding_size: int dtype: jnp.dtype = jnp.float32 @@ -498,11 +648,27 @@ def setup(self): self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): - hidden_states = self.electra( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - )[0] + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] prediction_scores = self.generator_predictions(hidden_states) if self.config.tie_word_embeddings: @@ -511,7 +677,14 @@ def __call__( else: prediction_scores = self.generator_lm_head(prediction_scores) - return (prediction_scores,) + if not return_dict: + return (prediction_scores,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings("""Electra Model with a `language modeling` head on top. """, ELECTRA_START_DOCSTRING) @@ -519,6 +692,11 @@ class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel): module_class = FlaxElectraForMaskedLMModule +append_call_sample_docstring( + FlaxElectraForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC +) + + class FlaxElectraForPreTrainingModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 @@ -528,15 +706,39 @@ def setup(self): self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - hidden_states = self.electra( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - )[0] + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] logits = self.discriminator_predictions(hidden_states) - return (logits,) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxElectraForPreTrainingOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -551,6 +753,31 @@ class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel): module_class = FlaxElectraForPreTrainingModule +FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import ElectraTokenizer, FlaxElectraForPreTraining + + >>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator') + >>> model = FlaxElectraForPreTraining.from_pretrained('google/electra-small-discriminator') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits +""" + +overwrite_call_docstring( + FlaxElectraForPreTraining, + ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC +) + + class FlaxElectraForTokenClassificationModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 @@ -561,16 +788,40 @@ def setup(self): self.classifier = nn.Dense(self.config.num_labels) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - hidden_states = self.electra( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - )[0] + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] hidden_states = self.dropout(hidden_states, deterministic=deterministic) logits = self.classifier(hidden_states) - return (logits,) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -585,6 +836,15 @@ class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel): module_class = FlaxElectraForTokenClassificationModule +append_call_sample_docstring( + FlaxElectraForTokenClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + def identity(x, **kwargs): return x @@ -668,7 +928,15 @@ def setup(self): self.classifier = nn.Dense(1, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): num_choices = input_ids.shape[1] input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None @@ -677,15 +945,30 @@ def __call__( position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None # Model - hidden_states = self.electra( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - )[0] + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic) logits = self.classifier(pooled_output) reshaped_logits = logits.reshape(-1, num_choices) - return (reshaped_logits,) + if not return_dict: + return (reshaped_logits,) + outputs[1:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -703,6 +986,13 @@ class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel): overwrite_call_docstring( FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") ) +append_call_sample_docstring( + FlaxElectraForMultipleChoice, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) class FlaxElectraForQuestionAnsweringModule(nn.Module): @@ -714,19 +1004,42 @@ def setup(self): self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - hidden_states = self.electra( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) - + hidden_states = outputs[0] logits = self.qa_outputs(hidden_states) start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) - return (start_logits, end_logits) + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -740,6 +1053,15 @@ class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel): module_class = FlaxElectraForQuestionAnsweringModule +append_call_sample_docstring( + FlaxElectraForQuestionAnswering, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + class FlaxElectraClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -770,15 +1092,38 @@ def setup(self): self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - hidden_states = self.electra( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - )[0] + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] logits = self.classifier(hidden_states, deterministic=deterministic) - return (logits,) + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -790,3 +1135,12 @@ def __call__( ) class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel): module_class = FlaxElectraForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxElectraForSequenceClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) diff --git a/tests/test_modeling_flax_electra.py b/tests/test_modeling_flax_electra.py index bcc1cb1c2fa9fc..2e15f94402bb16 100644 --- a/tests/test_modeling_flax_electra.py +++ b/tests/test_modeling_flax_electra.py @@ -126,8 +126,8 @@ def setUp(self): def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: if model_class_name == FlaxElectraForMaskedLM: - model = model_class_name.from_pretrained("google/electra-small-generator", from_pt=True) + model = model_class_name.from_pretrained("google/electra-small-generator") else: - model = model_class_name.from_pretrained("google/electra-small-discriminator", from_pt=True) + model = model_class_name.from_pretrained("google/electra-small-discriminator") outputs = model(np.ones((1, 1))) self.assertIsNotNone(outputs) From 4fad31d67a30956dea0192699a36b77e923785ce Mon Sep 17 00:00:00 2001 From: CoderPat Date: Tue, 4 May 2021 12:07:20 +0100 Subject: [PATCH 10/11] remove stale class --- src/transformers/modeling_flax_utils.py | 111 ------------------------ 1 file changed, 111 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 682465311b59d4..51e65f37b2a2d6 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -57,10 +57,6 @@ } -def identity(x, **kwargs): - return x - - class FlaxPreTrainedModel(PushToHubMixin): r""" Base class for all models. @@ -431,113 +427,6 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=F logger.info(f"Model pushed to the hub in this commit: {url}") -class SequenceSummary(nn.Module): - r""" - Compute a single vector summary of a sequence hidden states. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The config used by the model. Relevant arguments in the config class of the model are (refer to the actual - config class of your model for the default values it uses): - - - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are: - - - :obj:`"last"` -- Take the last token hidden state (like XLNet) - - :obj:`"first"` -- Take the first token hidden state (like Bert) - - :obj:`"mean"` -- Take the mean of all tokens hidden states - - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) - - :obj:`"attn"` -- Not implemented now, use multi-head attention - - - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction. - - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to - :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`). - - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the - output, another string or :obj:`None` will add no activation. - - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and - activation. - - **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and - activation. - """ - config: PretrainedConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - - self.summary_type = getattr(self.config, "summary_type", "last") - if self.summary_type == "attn": - # We should use a standard multi-head attention module with absolute positional embedding for that. - # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 - # We can probably just use the multi-head attention module of PyTorch >=1.1.0 - raise NotImplementedError - - self.summary = identity - if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj: - if ( - hasattr(self.config, "summary_proj_to_labels") - and self.config.summary_proj_to_labels - and self.config.num_labels > 0 - ): - num_classes = self.config.num_labels - else: - num_classes = self.config.hidden_size - self.summary = nn.Dense(num_classes, dtype=self.dtype) - - activation_string = getattr(self.config, "summary_activation", None) - self.activation = ACT2FN[activation_string] if activation_string else lambda x: x - - self.first_dropout = identity - if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0: - self.first_dropout = nn.Dropout(self.config.summary_first_dropout) - - self.last_dropout = identity - if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0: - self.last_dropout = nn.Dropout(self.config.summary_last_dropout) - - def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): - """ - Compute a single vector summary of a sequence hidden states. - - Args: - hidden_states (:obj:`jnp.array` of shape :obj:`[batch_size, seq_len, hidden_size]`): - The hidden states of the last layer. - cls_index (:obj:`jnp.array` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`): - Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification - token. - - Returns: - :obj:`jnp.array`: The summary of the sequence hidden states. - """ - if self.summary_type == "last": - output = hidden_states[:, -1] - elif self.summary_type == "first": - output = hidden_states[:, 0] - elif self.summary_type == "mean": - output = hidden_states.mean(dim=1) - elif self.summary_type == "cls_index": - if cls_index is None: - cls_index = jnp.full_like( - hidden_states[..., :1, :], - hidden_states.shape[-2] - 1, - dtype=jnp.long, - ) - else: - # TODO: - raise NotImplementedError - # cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) - # cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) - # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states - output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) - elif self.summary_type == "attn": - raise NotImplementedError - - output = self.first_dropout(output, deterministic=deterministic) - output = self.summary(output) - output = self.activation(output) - output = self.last_dropout(output, deterministic=deterministic) - - return output - - def overwrite_call_docstring(model_class, docstring): # copy __call__ function to be sure docstring is changed only for this function model_class.__call__ = copy_func(model_class.__call__) From b60c2ffe36153af20d4520da1256fb27c51acfe9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 4 May 2021 20:40:43 +0200 Subject: [PATCH 11/11] add copied from --- src/transformers/models/electra/modeling_flax_electra.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 66ef9a51871ebc..9482e2263d10a9 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -166,6 +166,7 @@ def setup(self): self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): # Embed inputs_embeds = self.word_embeddings(input_ids.astype("i4"))