diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index 7b8ce142e044..69f67d7f56ff 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM :members: +FlaxAutoModelForSeq2SeqLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM + :members: + + FlaxAutoModelForSequenceClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9fcf97b119da..7f367a68046f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1514,6 +1514,7 @@ "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_MAPPING", @@ -1524,6 +1525,7 @@ "FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForPreTraining", "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", "FlaxAutoModelForTokenClassification", ] @@ -2851,6 +2853,7 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_MAPPING, @@ -2861,6 +2864,7 @@ FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForPreTraining, FlaxAutoModelForQuestionAnswering, + FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, FlaxAutoModelForTokenClassification, ) diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 21238894787d..d483b271b873 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -92,6 +92,7 @@ "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_MAPPING", @@ -103,6 +104,7 @@ "FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForPreTraining", "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", "FlaxAutoModelForTokenClassification", ] @@ -178,6 +180,7 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_MAPPING, @@ -189,6 +192,7 @@ FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForPreTraining, FlaxAutoModelForQuestionAnswering, + FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, FlaxAutoModelForTokenClassification, ) diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index ff59d35c6260..be03814c3be7 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -129,6 +129,13 @@ ] ) +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( + [ + # Model for Seq2Seq Causal LM mapping + (BartConfig, FlaxBartForConditionalGeneration) + ] +) + FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Sequence Classification mapping @@ -197,6 +204,13 @@ "FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling" ) + +FlaxAutoModelForSeq2SeqLM = auto_class_factory( + "FlaxAutoModelForSeq2SeqLM", + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + head_doc="sequence-to-sequence language modeling", +) + FlaxAutoModelForSequenceClassification = auto_class_factory( "FlaxAutoModelForSequenceClassification", FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 7bae4a9a763e..7ad7ee76b6cd 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -94,6 +94,9 @@ def from_pretrained(cls, *args, **kwargs): FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None + + FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None @@ -166,6 +169,15 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) +class FlaxAutoModelForSeq2SeqLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxAutoModelForSequenceClassification: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"])