diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index b3d5bcc5155a78..7f697aa8c37991 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -121,6 +121,7 @@ from .modeling_tf_mobilebert import ( TFMobileBertForMaskedLM, TFMobileBertForMultipleChoice, + TFMobileBertForNextSentencePrediction, TFMobileBertForPreTraining, TFMobileBertForQuestionAnswering, TFMobileBertForSequenceClassification, @@ -359,6 +360,7 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( [ (BertConfig, TFBertForNextSentencePrediction), + (MobileBertConfig, TFMobileBertForNextSentencePrediction), ] ) diff --git a/utils/check_repo.py b/utils/check_repo.py index f81b4c9fef3904..a563ff9471a23e 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -87,10 +87,8 @@ "RagSequenceForGeneration", "RagTokenForGeneration", "T5Stack", - "TFBertForNextSentencePrediction", "TFFunnelBaseModel", "TFGPT2DoubleHeadsModel", - "TFMobileBertForNextSentencePrediction", "TFOpenAIGPTDoubleHeadsModel", "XLMForQuestionAnswering", "XLMProphetNetDecoder",