
[Flax] Add Electra models #11426

Merged · 14 commits · May 4, 2021
docs/source/index.rst (2 changes: 1 addition & 1 deletion)
@@ -295,7 +295,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ |
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
Contributor: Yaaaay :-)

Member: 😍

+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
docs/source/model_doc/electra.rst (49 changes: 49 additions & 0 deletions)
@@ -185,3 +185,52 @@ TFElectraForQuestionAnswering

.. autoclass:: transformers.TFElectraForQuestionAnswering
:members: call


FlaxElectraModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxElectraModel
:members: __call__


FlaxElectraForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxElectraForPreTraining
:members: __call__


FlaxElectraForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxElectraForMaskedLM
:members: __call__


FlaxElectraForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxElectraForSequenceClassification
:members: __call__


FlaxElectraForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxElectraForMultipleChoice
:members: __call__


FlaxElectraForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxElectraForTokenClassification
:members: __call__


FlaxElectraForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxElectraForQuestionAnswering
:members: __call__
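
As a point of reference (not part of the diff), a minimal usage sketch for the newly documented classes; `from_pt=True` is a hedge in case the checkpoint ships only PyTorch weights:

# Sketch: score tokens with the Flax ELECTRA discriminator (replaced-token detection).
from transformers import ElectraTokenizerFast, FlaxElectraForPreTraining

tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
model = FlaxElectraForPreTraining.from_pretrained(
    "google/electra-small-discriminator", from_pt=True
)

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="jax")
logits = model(**inputs).logits  # one replaced-token score per input token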
src/transformers/__init__.py (22 changes: 22 additions & 0 deletions)
@@ -1403,6 +1403,18 @@
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
"FlaxElectraForQuestionAnswering",
"FlaxElectraForSequenceClassification",
"FlaxElectraForTokenClassification",
"FlaxElectraModel",
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",
@@ -2585,6 +2597,16 @@
FlaxBertModel,
FlaxBertPreTrainedModel,
)
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
FlaxElectraForQuestionAnswering,
FlaxElectraForSequenceClassification,
FlaxElectraForTokenClassification,
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
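
The two hunks above follow the library's lazy-import convention: each public class is registered in `_import_structure` for runtime lookup and mirrored under `TYPE_CHECKING` for static analysis. A generic sketch of the underlying idea (illustrative only; the actual machinery is `_BaseLazyModule` in `file_utils` and differs in detail):

# Illustrative sketch of lazy attribute loading via PEP 562: a module-level
# __getattr__ imports the submodule on first access, so `import transformers`
# stays cheap even with many optional backends registered.
import importlib

_import_structure = {"models.electra": ["FlaxElectraModel", "FlaxElectraForMaskedLM"]}

def __getattr__(name):
    for submodule, exported_names in _import_structure.items():
        if name in exported_names:
            module = importlib.import_module(f".{submodule}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")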
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
src/transformers/models/auto/modeling_flax_auto.py (18 changes: 17 additions & 1 deletion)
@@ -28,6 +28,15 @@
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..electra.modeling_flax_electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
FlaxElectraForQuestionAnswering,
FlaxElectraForSequenceClassification,
FlaxElectraForTokenClassification,
FlaxElectraModel,
)
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
@@ -37,7 +46,7 @@
FlaxRobertaModel,
)
from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, RobertaConfig
from .configuration_auto import BertConfig, ElectraConfig, RobertaConfig


logger = logging.get_logger(__name__)
@@ -48,6 +57,7 @@
# Base model mapping
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
(ElectraConfig, FlaxElectraModel),
]
)

@@ -56,6 +66,7 @@
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
(ElectraConfig, FlaxElectraForPreTraining),
]
)

@@ -64,6 +75,7 @@
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
(ElectraConfig, FlaxElectraForMaskedLM),
]
)

@@ -72,6 +84,7 @@
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
(ElectraConfig, FlaxElectraForSequenceClassification),
]
)

@@ -80,6 +93,7 @@
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
(ElectraConfig, FlaxElectraForQuestionAnswering),
]
)

@@ -88,6 +102,7 @@
# Model for Token Classification mapping
(RobertaConfig, FlaxRobertaForTokenClassification),
(BertConfig, FlaxBertForTokenClassification),
(ElectraConfig, FlaxElectraForTokenClassification),
]
)

@@ -96,6 +111,7 @@
# Model for Multiple Choice mapping
(RobertaConfig, FlaxRobertaForMultipleChoice),
(BertConfig, FlaxBertForMultipleChoice),
(ElectraConfig, FlaxElectraForMultipleChoice),
]
)
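
With `ElectraConfig` added to every mapping above, the Flax auto classes should dispatch ELECTRA checkpoints to the new model classes. A quick sketch of that dispatch; `from_pt=True` is a hedge in case the checkpoint hosts only PyTorch weights:

# Sketch: verify auto-class dispatch picks the Flax Electra implementation.
from transformers import AutoTokenizer, FlaxAutoModel

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
model = FlaxAutoModel.from_pretrained("google/electra-small-discriminator", from_pt=True)

inputs = tokenizer("ELECTRA, now in Flax.", return_tensors="jax")
outputs = model(**inputs)
print(type(model).__name__)  # expected: FlaxElectraModel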

src/transformers/models/electra/__init__.py (32 changes: 31 additions & 1 deletion)
@@ -18,7 +18,13 @@

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)


_import_structure = {
@@ -56,6 +62,18 @@
"TFElectraPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_electra"] = [
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
"FlaxElectraForQuestionAnswering",
"FlaxElectraForSequenceClassification",
"FlaxElectraForTokenClassification",
"FlaxElectraModel",
"FlaxElectraPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
@@ -91,6 +109,18 @@
TFElectraPreTrainedModel,
)

if is_flax_available():
from .modeling_flax_electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
FlaxElectraForQuestionAnswering,
FlaxElectraForSequenceClassification,
FlaxElectraForTokenClassification,
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)

else:
import importlib
import os
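
Finally, the `is_flax_available()` guards above keep `transformers` importable when JAX/Flax is not installed. A quick sketch of the pattern from a user's perspective:

# Sketch: probe for the optional Flax backend before using the new classes.
from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxElectraModel  # safe: JAX/Flax are installed
else:
    print("Flax is not installed; the Flax ELECTRA classes are unavailable.")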