diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py
index 02bc7846245e7e..dcf694e3395236 100644
--- a/src/transformers/modeling_tf_auto.py
+++ b/src/transformers/modeling_tf_auto.py
@@ -61,6 +61,7 @@ from .modeling_tf_bert import (
     TFBertForMaskedLM,
     TFBertForMultipleChoice,
+    TFBertForNextSentencePrediction,
     TFBertForPreTraining,
     TFBertForQuestionAnswering,
     TFBertForSequenceClassification,
@@ -120,6 +121,7 @@ from .modeling_tf_mobilebert import (
     TFMobileBertForMaskedLM,
     TFMobileBertForMultipleChoice,
+    TFMobileBertForNextSentencePrediction,
     TFMobileBertForPreTraining,
     TFMobileBertForQuestionAnswering,
     TFMobileBertForSequenceClassification,
@@ -355,6 +357,13 @@
     ]
 )
 
+TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
+    [
+        (BertConfig, TFBertForNextSentencePrediction),
+        (MobileBertConfig, TFMobileBertForNextSentencePrediction),
+    ]
+)
+
 
 TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
@@ -1412,3 +1421,101 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 ", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
             )
         )
+
+
+class TFAutoModelForNextSentencePrediction:
+    r"""
+    This is a generic model class that will be instantiated as one of the model classes of the library---with a
+    next sentence prediction head---when created with the
+    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
+    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.
+
+    This class cannot be instantiated directly using ``__init__()`` (throws an error).
+    """
+
+    def __init__(self):
+        raise EnvironmentError(
+            "TFAutoModelForNextSentencePrediction is designed to be instantiated "
+            "using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
+            "`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
+    def from_config(cls, config):
+        r"""
+        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
+        configuration.
+
+        Note:
+            Loading a model from its configuration file does **not** load the model weights. It only affects the
+            model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
+            load the model weights.
+
+        Args:
+            config (:class:`~transformers.PretrainedConfig`):
+                The model class to instantiate is selected based on the configuration class:
+
+                List options
+
+        Examples::
+
+            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
+            >>> # Download configuration from S3 and cache.
+            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
+            >>> model = TFAutoModelForNextSentencePrediction.from_config(config)
+        """
+        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
+    @add_start_docstrings(
+        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
+        "pretrained model.",
+        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
+    )
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Examples::
+
+            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
+
+            >>> # Download model and configuration from S3 and cache.
+            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
+
+            >>> # Update configuration during loading
+            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
+            >>> model.config.output_attentions
+            True
+
+            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
+            >>> config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
+            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
+        """
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
+
+        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
diff --git a/utils/check_repo.py b/utils/check_repo.py
index f81b4c9fef3904..a563ff9471a23e 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -87,10 +87,8 @@
     "RagSequenceForGeneration",
     "RagTokenForGeneration",
     "T5Stack",
-    "TFBertForNextSentencePrediction",
     "TFFunnelBaseModel",
     "TFGPT2DoubleHeadsModel",
-    "TFMobileBertForNextSentencePrediction",
     "TFOpenAIGPTDoubleHeadsModel",
     "XLMForQuestionAnswering",
     "XLMProphetNetDecoder",
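
A minimal end-to-end sketch of the auto class this patch exposes, in the same doctest style as the docstrings above. It is illustrative, not part of the patch: it assumes TensorFlow is installed, the 'bert-base-uncased' weights are reachable, the sentence pair is made up, and the class-index convention in the comments is BERT's usual NSP labeling.

    >>> import tensorflow as tf
    >>> from transformers import AutoTokenizer, TFAutoModelForNextSentencePrediction

    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')

    >>> # Encode a candidate sentence pair; token_type_ids mark the two segments.
    >>> encoding = tokenizer('The sky is clear today.', 'It has not rained in weeks.', return_tensors='tf')

    >>> # The first output is a (batch_size, 2) logits tensor. For BERT's NSP head,
    >>> # index 0 scores "sentence B follows sentence A", index 1 scores "it does not".
    >>> logits = model(encoding)[0]
    >>> probs = tf.nn.softmax(logits, axis=-1)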