Skip to content

Commit

Permalink
[Flax] Add FlaxMBart (#12236)
Browse files Browse the repository at this point in the history
* Copy BART to MBart and rename some stuff

* Add copy statements pointing to FlaxBart

* Update/add some common files

* Update shift_tokens_right + fix imports

* Fix shift_tokens_right method according to MBart implementation

* Update shift_tokens_right in tests accordingly

* Fix the import issue and update docs file
* make style quality

* Do some minor changes according to patil-suraj suggestions

* Change the order of normalization layer and attention

* Add some copy statements

* Update generate method and add integration test for mBart

* Make a few updates after a review

Besides, add `lang_code_to_id` to MBartTokenizerFast

* fix-copies; make style quality

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* fix output type, style

* add copied from

* resolve conflicts

Co-authored-by: Suraj Patil <surajp815@gmail.com>
  • Loading branch information
stancld and patil-suraj authored Jul 7, 2021
1 parent 2d42915 commit 61400e1
Show file tree
Hide file tree
Showing 9 changed files with 2,336 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLNet ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mBART ||||| |
| mBART ||||| |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mT5 ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
28 changes: 28 additions & 0 deletions docs/source/model_doc/mbart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,31 @@ TFMBartForConditionalGeneration

.. autoclass:: transformers.TFMBartForConditionalGeneration
:members: call


FlaxMBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxMBartModel
:members: __call__, encode, decode


FlaxMBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxMBartForConditionalGeneration
:members: __call__, encode, decode


FlaxMBartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxMBartForSequenceClassification
:members: __call__, encode, decode


FlaxMBartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxMBartForQuestionAnswering
:members: __call__, encode, decode
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,15 @@
_import_structure["models.gpt_neo"].extend(
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.mbart"].extend(
[
"FlaxMBartForConditionalGeneration",
"FlaxMBartForQuestionAnswering",
"FlaxMBartForSequenceClassification",
"FlaxMBartModel",
"FlaxMBartPreTrainedModel",
]
)
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",
Expand Down Expand Up @@ -3019,6 +3028,13 @@
)
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
FlaxMBartForSequenceClassification,
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@
)
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
from ..mbart.modeling_flax_mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
FlaxMBartForSequenceClassification,
FlaxMBartModel,
)
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
Expand All @@ -75,6 +81,7 @@
ElectraConfig,
GPT2Config,
GPTNeoConfig,
MBartConfig,
RobertaConfig,
T5Config,
ViTConfig,
Expand All @@ -97,6 +104,7 @@
(ElectraConfig, FlaxElectraModel),
(CLIPConfig, FlaxCLIPModel),
(ViTConfig, FlaxViTModel),
(MBartConfig, FlaxMBartModel),
(T5Config, FlaxT5Model),
(Wav2Vec2Config, FlaxWav2Vec2Model),
]
Expand All @@ -110,6 +118,7 @@
(BigBirdConfig, FlaxBigBirdForPreTraining),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForPreTraining),
(MBartConfig, FlaxMBartForConditionalGeneration),
(T5Config, FlaxT5ForConditionalGeneration),
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
]
Expand All @@ -123,6 +132,7 @@
(BigBirdConfig, FlaxBigBirdForMaskedLM),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForMaskedLM),
(MBartConfig, FlaxMBartForConditionalGeneration),
]
)

Expand Down Expand Up @@ -157,6 +167,7 @@
(BigBirdConfig, FlaxBigBirdForSequenceClassification),
(BartConfig, FlaxBartForSequenceClassification),
(ElectraConfig, FlaxElectraForSequenceClassification),
(MBartConfig, FlaxMBartForSequenceClassification),
]
)

Expand All @@ -168,6 +179,7 @@
(BigBirdConfig, FlaxBigBirdForQuestionAnswering),
(BartConfig, FlaxBartForQuestionAnswering),
(ElectraConfig, FlaxElectraForQuestionAnswering),
(MBartConfig, FlaxMBartForQuestionAnswering),
]
)

Expand Down
19 changes: 19 additions & 0 deletions src/transformers/models/mbart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
Expand Down Expand Up @@ -56,6 +57,15 @@
"TFMBartPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_mbart"] = [
"FlaxMBartForConditionalGeneration",
"FlaxMBartForQuestionAnswering",
"FlaxMBartForSequenceClassification",
"FlaxMBartModel",
"FlaxMBartPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
Expand All @@ -82,6 +92,15 @@
if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel

if is_flax_available():
from .modeling_flax_mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
FlaxMBartForSequenceClassification,
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)

else:
import importlib
import os
Expand Down
Loading

0 comments on commit 61400e1

Please sign in to comment.