diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index 722cafa87554..07ff0b8d6d86 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -214,6 +214,14 @@ The following auto classes are available for the following natural language proc [[autodoc]] FlaxAutoModelForQuestionAnswering +### AutoModelForTextEncoding + +[[autodoc]] AutoModelForTextEncoding + +### TFAutoModelForTextEncoding + +[[autodoc]] TFAutoModelForTextEncoding + ## Computer vision The following auto classes are available for the following computer vision tasks. diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7bb7a4342ea3..9d3c9a4ff232 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1053,6 +1053,7 @@ "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_TEXT_ENCODING_MAPPING", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", @@ -1087,6 +1088,7 @@ "AutoModelForSequenceClassification", "AutoModelForSpeechSeq2Seq", "AutoModelForTableQuestionAnswering", + "AutoModelForTextEncoding", "AutoModelForTokenClassification", "AutoModelForUniversalSegmentation", "AutoModelForVideoClassification", @@ -2984,6 +2986,7 @@ "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_TEXT_ENCODING_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", @@ -3003,6 +3006,7 @@ "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", "TFAutoModelForTableQuestionAnswering", + "TFAutoModelForTextEncoding", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", "TFAutoModelForZeroShotImageClassification", @@ -4807,6 +4811,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TEXT_ENCODING_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, @@ -4841,6 +4846,7 @@ AutoModelForSequenceClassification, AutoModelForSpeechSeq2Seq, AutoModelForTableQuestionAnswering, + AutoModelForTextEncoding, AutoModelForTokenClassification, AutoModelForUniversalSegmentation, AutoModelForVideoClassification, @@ -6374,6 +6380,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_TEXT_ENCODING_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, @@ -6393,6 +6400,7 @@ TFAutoModelForSequenceClassification, TFAutoModelForSpeechSeq2Seq, TFAutoModelForTableQuestionAnswering, + TFAutoModelForTextEncoding, TFAutoModelForTokenClassification, TFAutoModelForVision2Seq, TFAutoModelForZeroShotImageClassification, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index df9958a80e4d..26b6609c28fc 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -64,6 +64,7 @@ "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_TEXT_ENCODING_MAPPING", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", @@ -85,6 +86,7 @@ "AutoModelForImageSegmentation", "AutoModelForInstanceSegmentation", "AutoModelForMaskGeneration", + "AutoModelForTextEncoding", "AutoModelForMaskedImageModeling", "AutoModelForMaskedLM", "AutoModelForMultipleChoice", @@ -131,6 +133,7 @@ "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_TEXT_ENCODING_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", @@ -150,6 +153,7 @@ "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", "TFAutoModelForTableQuestionAnswering", + "TFAutoModelForTextEncoding", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", "TFAutoModelForZeroShotImageClassification", @@ -233,6 +237,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TEXT_ENCODING_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, @@ -267,6 +272,7 @@ AutoModelForSequenceClassification, AutoModelForSpeechSeq2Seq, AutoModelForTableQuestionAnswering, + AutoModelForTextEncoding, AutoModelForTokenClassification, AutoModelForUniversalSegmentation, AutoModelForVideoClassification, @@ -300,6 +306,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_TEXT_ENCODING_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, @@ -319,6 +326,7 @@ TFAutoModelForSequenceClassification, TFAutoModelForSpeechSeq2Seq, TFAutoModelForTableQuestionAnswering, + TFAutoModelForTextEncoding, TFAutoModelForTokenClassification, TFAutoModelForVision2Seq, TFAutoModelForZeroShotImageClassification, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4c675ff05097..d0c4a236c7c0 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1011,6 +1011,36 @@ ] ) +MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( + [ + ("albert", "AlbertModel"), + ("bert", "BertModel"), + ("big_bird", "BigBirdModel"), + ("data2vec-text", "Data2VecTextModel"), + ("deberta", "DebertaModel"), + ("deberta-v2", "DebertaV2Model"), + ("distilbert", "DistilBertModel"), + ("electra", "ElectraModel"), + ("flaubert", "FlaubertModel"), + ("ibert", "IBertModel"), + ("longformer", "LongformerModel"), + ("mobilebert", "MobileBertModel"), + ("mt5", "MT5EncoderModel"), + ("nystromformer", "NystromformerModel"), + ("reformer", "ReformerModel"), + ("rembert", "RemBertModel"), + ("roberta", "RobertaModel"), + ("roberta-prelayernorm", "RobertaPreLayerNormModel"), + ("roc_bert", "RoCBertModel"), + ("roformer", "RoFormerModel"), + ("squeezebert", "SqueezeBertModel"), + ("t5", "T5EncoderModel"), + ("xlm", "XLMModel"), + ("xlm-roberta", "XLMRobertaModel"), + ("xlm-roberta-xl", "XLMRobertaXLModel"), + ] +) + MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) @@ -1088,11 +1118,17 @@ MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) +MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) + class AutoModelForMaskGeneration(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING +class AutoModelForTextEncoding(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING + + class AutoModel(_BaseAutoModelClass): _model_mapping = MODEL_MAPPING diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index bd86431c8cb2..10e33cbebbe6 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -437,6 +437,28 @@ ("sam", "TFSamModel"), ] ) +TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( + [ + ("albert", "TFAlbertModel"), + ("bert", "TFBertModel"), + ("convbert", "TFConvBertModel"), + ("deberta", "TFDebertaModel"), + ("deberta-v2", "TFDebertaV2Model"), + ("distilbert", "TFDistilBertModel"), + ("electra", "TFElectraModel"), + ("flaubert", "TFFlaubertModel"), + ("longformer", "TFLongformerModel"), + ("mobilebert", "TFMobileBertModel"), + ("mt5", "TFMT5EncoderModel"), + ("rembert", "TFRemBertModel"), + ("roberta", "TFRobertaModel"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"), + ("roformer", "TFRoFormerModel"), + ("t5", "TFT5EncoderModel"), + ("xlm", "TFXLMModel"), + ("xlm-roberta", "TFXLMRobertaModel"), + ] +) TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES) TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) @@ -491,11 +513,17 @@ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES ) +TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) + class TFAutoModelForMaskGeneration(_BaseAutoModelClass): _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING +class TFAutoModelForTextEncoding(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING + + class TFAutoModel(_BaseAutoModelClass): _model_mapping = TF_MODEL_MAPPING diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1490adf3b20e..2ca6f0156222 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -524,6 +524,9 @@ def __init__(self, *args, **kwargs): MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None +MODEL_FOR_TEXT_ENCODING_MAPPING = None + + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None @@ -726,6 +729,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class AutoModelForTextEncoding(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelForTokenClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 4da32ae6034a..e55684402805 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -264,6 +264,9 @@ def __init__(self, *args, **kwargs): TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None +TF_MODEL_FOR_TEXT_ENCODING_MAPPING = None + + TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None @@ -377,6 +380,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFAutoModelForTextEncoding(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFAutoModelForTokenClassification(metaclass=DummyObject): _backends = ["tf"]