Skip to content

Commit

Permalink
[Flax] FlaxAutoModelForSeq2SeqLM (huggingface#12228)
Browse files Browse the repository at this point in the history
* add FlaxAutoModelForSeq2SeqLM
  • Loading branch information
patil-suraj authored and Iwontbecreative committed Jul 15, 2021
1 parent 4fa9427 commit 5cc3440
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/source/model_doc/auto.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM
:members:


FlaxAutoModelForSeq2SeqLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM
:members:


FlaxAutoModelForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,7 @@
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING",
Expand All @@ -1524,6 +1525,7 @@
"FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
]
Expand Down Expand Up @@ -2851,6 +2853,7 @@
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING,
Expand All @@ -2861,6 +2864,7 @@
FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING",
Expand All @@ -103,6 +104,7 @@
"FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
]
Expand Down Expand Up @@ -178,6 +180,7 @@
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING,
Expand All @@ -189,6 +192,7 @@
FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@
]
)

FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(BartConfig, FlaxBartForConditionalGeneration)
]
)

FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
Expand Down Expand Up @@ -197,6 +204,13 @@
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
)


FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM",
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
)

FlaxAutoModelForSequenceClassification = auto_class_factory(
"FlaxAutoModelForSequenceClassification",
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def from_pretrained(cls, *args, **kwargs):
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None


FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None


FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None


Expand Down Expand Up @@ -166,6 +169,15 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxAutoModelForSeq2SeqLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxAutoModelForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
Expand Down

0 comments on commit 5cc3440

Please sign in to comment.