Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flax] Add Electra models #11426

Merged
merged 14 commits into from
May 4, 2021
23 changes: 23 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,19 @@
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
"FlaxElectraForQuestionAnswering",
"FlaxElectraForSequenceClassification",
"FlaxElectraForTokenClassification",
"FlaxElectraModel",
"FlaxElectraPreTrainedModel",
]
)

_import_structure["models.roberta"].append("FlaxRobertaModel")
else:
from .utils import dummy_flax_objects
Expand Down Expand Up @@ -2551,6 +2564,16 @@
FlaxBertModel,
FlaxBertPreTrainedModel,
)
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
FlaxElectraForQuestionAnswering,
FlaxElectraForSequenceClassification,
FlaxElectraForTokenClassification,
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
from .models.roberta import FlaxRobertaModel
else:
# Import the same objects as dummies to get them in the namespace.
Expand Down
136 changes: 135 additions & 1 deletion src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
import os
from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union
from typing import Callable, Dict, Set, Tuple, Union

import numpy as np

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

from .configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -55,6 +58,10 @@
}


def identity(x, **kwargs):
CoderPat marked this conversation as resolved.
Show resolved Hide resolved
return x


class FlaxPreTrainedModel(PushToHubMixin):
r"""
Base class for all models.
Expand Down Expand Up @@ -425,6 +432,133 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=F
logger.info(f"Model pushed to the hub in this commit: {url}")


class SequenceSummary(nn.Module):
CoderPat marked this conversation as resolved.
Show resolved Hide resolved
r"""
Compute a single vector summary of a sequence hidden states.

Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
config class of your model for the default values it uses):

- **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:

- :obj:`"last"` -- Take the last token hidden state (like XLNet)
- :obj:`"first"` -- Take the first token hidden state (like Bert)
- :obj:`"mean"` -- Take the mean of all tokens hidden states
- :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
- :obj:`"attn"` -- Not implemented now, use multi-head attention

- **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
- **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
:obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
- **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
output, another string or :obj:`None` will add no activation.
- **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
activation.
- **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and
activation.
"""
config: PretrainedConfig
dtype: jnp.dtype = jnp.float32

def setup(self):

self.summary_type = getattr(self.config, "summary_type", "last")
if self.summary_type == "attn":
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise NotImplementedError

self.summary = identity
if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj:
if (
hasattr(self.config, "summary_proj_to_labels")
and self.config.summary_proj_to_labels
and self.config.num_labels > 0
):
num_classes = self.config.num_labels
else:
num_classes = self.config.hidden_size
self.summary = nn.Dense(num_classes, dtype=self.dtype)

activation_string = getattr(self.config, "summary_activation", None)
self.activation = ACT2FN[activation_string] if activation_string else lambda x: x

self.first_dropout = identity
if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0:
self.first_dropout = nn.Dropout(self.config.summary_first_dropout)

self.last_dropout = identity
if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(self.config.summary_last_dropout)

def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
"""
Compute a single vector summary of a sequence hidden states.

Args:
hidden_states (:obj:`jnp.array` of shape :obj:`[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
cls_index (:obj:`jnp.array` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification
token.

Returns:
:obj:`jnp.array`: The summary of the sequence hidden states.
"""
if self.summary_type == "last":
output = hidden_states[:, -1]
elif self.summary_type == "first":
output = hidden_states[:, 0]
elif self.summary_type == "mean":
output = hidden_states.mean(dim=1)
elif self.summary_type == "cls_index":
if cls_index is None:
cls_index = jnp.full_like(
hidden_states[..., :1, :],
hidden_states.shape[-2] - 1,
dtype=jnp.long,
)
else:
# TODO:
raise NotImplementedError
# cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
# cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
elif self.summary_type == "attn":
raise NotImplementedError

output = self.first_dropout(output, deterministic=deterministic)
output = self.summary(output)
output = self.activation(output)
output = self.last_dropout(output, deterministic=deterministic)

return output


class TiedDense(nn.Module):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add to do this gimmick since the way weight sharing in FlaxBert was done doesn't work for Electra and other models since they don't explictly separate the bias from the kernel (as is the case in Bert)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a cool class! However since ELECTRA doesn't tie word embeddings, I don't think we need it at the moment :-) Could we maybe leave it out for now and remove the tie embedding logic in ELECTRA?

embedding_size: int
dtype: jnp.dtype = jnp.float32
precision = None
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

def setup(self):
bias = self.param("bias", self.bias_init, (self.embedding_size,))
self.bias = jnp.asarray(bias, dtype=self.dtype)

def __call__(self, x, kernel):
y = lax.dot_general(
x,
kernel,
(((x.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)
return y + self.bias


def overwrite_call_docstring(model_class, docstring):
# copy __call__ function to be sure docstring is changed only for this function
model_class.__call__ = copy_func(model_class.__call__)
Expand Down
53 changes: 26 additions & 27 deletions src/transformers/models/electra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)


_import_structure = {
Expand Down Expand Up @@ -56,40 +62,33 @@
"TFElectraPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_electra"] = [
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
"FlaxElectraForQuestionAnswering",
"FlaxElectraForSequenceClassification",
"FlaxElectraForTokenClassification",
"FlaxElectraModel",
"FlaxElectraPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .tokenization_electra import ElectraTokenizer
pass

if is_tokenizers_available():
from .tokenization_electra_fast import ElectraTokenizerFast
pass

if is_torch_available():
from .modeling_electra import (
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,
ElectraForQuestionAnswering,
ElectraForSequenceClassification,
ElectraForTokenClassification,
ElectraModel,
ElectraPreTrainedModel,
load_tf_weights_in_electra,
)
pass

if is_tf_available():
from .modeling_tf_electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForSequenceClassification,
TFElectraForTokenClassification,
TFElectraModel,
TFElectraPreTrainedModel,
)
pass

if is_flax_available():
pass

else:
import importlib
Expand Down
Loading