From d4996dc77f9af236fe210bfaf91e9f59e288eb69 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Mon, 9 Nov 2020 14:26:51 -0500
Subject: [PATCH] Check all models are in an auto class

---
Reviewer notes (this area is ignored when the patch is applied):
* src/transformers/modeling_tf_auto.py: imports LxmertConfig, TFLxmertModel and
  TFLxmertForPreTraining, and adds an LxmertConfig entry to TF_MODEL_MAPPING
  and to TF_MODEL_FOR_PRETRAINING_MAPPING.
* utils/check_repo.py: adds the IGNORE_NON_AUTO_CONFIGURED exception list and
  three functions (get_all_auto_configured_models,
  check_models_are_auto_configured, check_all_models_are_auto_configured);
  the last one is now called from check_repo_quality().
* Possible follow-ups, not applied here so the diff stays faithful to the
  commit: check_models_are_auto_configured() always returns a list, so the
  `is not None` guard in check_all_models_are_auto_configured() never filters
  anything; get_all_auto_configured_models() returns a list, so every
  `model_name not in all_auto_models` test is a linear scan.

 src/transformers/modeling_tf_auto.py |  4 ++
 utils/check_repo.py                  | 69 ++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+)

diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py
index 6f0ea863a9ceba..98d3516bb8b7dc 100644
--- a/src/transformers/modeling_tf_auto.py
+++ b/src/transformers/modeling_tf_auto.py
@@ -31,6 +31,7 @@
     FunnelConfig,
     GPT2Config,
     LongformerConfig,
+    LxmertConfig,
     MobileBertConfig,
     OpenAIGPTConfig,
     RobertaConfig,
@@ -113,6 +114,7 @@
 )
 from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
 from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
+from .modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
 from .modeling_tf_marian import TFMarianMTModel
 from .modeling_tf_mbart import TFMBartForConditionalGeneration
 from .modeling_tf_mobilebert import (
@@ -168,6 +170,7 @@
 
 TF_MODEL_MAPPING = OrderedDict(
     [
+        (LxmertConfig, TFLxmertModel),
         (T5Config, TFT5Model),
         (DistilBertConfig, TFDistilBertModel),
         (AlbertConfig, TFAlbertModel),
@@ -192,6 +195,7 @@
 
 TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
     [
+        (LxmertConfig, TFLxmertForPreTraining),
         (T5Config, TFT5ForConditionalGeneration),
         (DistilBertConfig, TFDistilBertForMaskedLM),
         (AlbertConfig, TFAlbertForPreTraining),
diff --git a/utils/check_repo.py b/utils/check_repo.py
index b23eb13230540d..f81b4c9fef3904 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -70,6 +70,34 @@
     "marian": "marian.rst",
 }
 
+# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
+# should **not** be the rule.
+IGNORE_NON_AUTO_CONFIGURED = [
+    "DPRContextEncoder",
+    "DPREncoder",
+    "DPRReader",
+    "DPRSpanPredictor",
+    "FlaubertForQuestionAnswering",
+    "FunnelBaseModel",
+    "GPT2DoubleHeadsModel",
+    "OpenAIGPTDoubleHeadsModel",
+    "ProphetNetDecoder",
+    "ProphetNetEncoder",
+    "RagModel",
+    "RagSequenceForGeneration",
+    "RagTokenForGeneration",
+    "T5Stack",
+    "TFBertForNextSentencePrediction",
+    "TFFunnelBaseModel",
+    "TFGPT2DoubleHeadsModel",
+    "TFMobileBertForNextSentencePrediction",
+    "TFOpenAIGPTDoubleHeadsModel",
+    "XLMForQuestionAnswering",
+    "XLMProphetNetDecoder",
+    "XLMProphetNetEncoder",
+    "XLNetForQuestionAnswering",
+]
+
 # This is to make sure the transformers module imported is the one in the repo.
 spec = importlib.util.spec_from_file_location(
     "transformers",
@@ -282,6 +310,45 @@ def check_all_models_are_documented():
         raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
 
 
+def get_all_auto_configured_models():
+    """ Return the list of all models in at least one auto class."""
+    result = set()  # To avoid duplicates we concatenate all model classes in a set.
+    for attr_name in dir(transformers.modeling_auto):
+        if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
+            result = result | set(getattr(transformers.modeling_auto, attr_name).values())
+    for attr_name in dir(transformers.modeling_tf_auto):
+        if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
+            result = result | set(getattr(transformers.modeling_tf_auto, attr_name).values())
+    return [cls.__name__ for cls in result]
+
+
+def check_models_are_auto_configured(module, all_auto_models):
+    """ Check models defined in module are each in an auto class."""
+    defined_models = get_models(module)
+    failures = []
+    for model_name, _ in defined_models:
+        if model_name not in all_auto_models and model_name not in IGNORE_NON_AUTO_CONFIGURED:
+            failures.append(
+                f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
+                "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
+                "`utils/check_repo.py`."
+            )
+    return failures
+
+
+def check_all_models_are_auto_configured():
+    """ Check all models are each in an auto class."""
+    modules = get_model_modules()
+    all_auto_models = get_all_auto_configured_models()
+    failures = []
+    for module in modules:
+        new_failures = check_models_are_auto_configured(module, all_auto_models)
+        if new_failures is not None:
+            failures += new_failures
+    if len(failures) > 0:
+        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
+
+
 _re_decorator = re.compile(r"^\s*@(\S+)\s+$")
 
 
@@ -325,6 +392,8 @@ def check_repo_quality():
     check_all_models_are_tested()
     print("Checking all models are properly documented.")
     check_all_models_are_documented()
+    print("Checking all models are in at least one auto class.")
+    check_all_models_are_auto_configured()
 
 
 if __name__ == "__main__":