Check all models are in an auto class (huggingface#8425)
sgugger authored and stas00 committed Nov 10, 2020
1 parent a542034 commit 8fe4159
Showing 2 changed files with 73 additions and 0 deletions.
src/transformers/modeling_tf_auto.py: 4 additions & 0 deletions
@@ -31,6 +31,7 @@
FunnelConfig,
GPT2Config,
LongformerConfig,
LxmertConfig,
MobileBertConfig,
OpenAIGPTConfig,
RobertaConfig,
@@ -113,6 +114,7 @@
)
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
from .modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from .modeling_tf_marian import TFMarianMTModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration
from .modeling_tf_mobilebert import (
@@ -168,6 +170,7 @@

TF_MODEL_MAPPING = OrderedDict(
[
(LxmertConfig, TFLxmertModel),
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
@@ -192,6 +195,7 @@

TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(LxmertConfig, TFLxmertForPreTraining),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForPreTraining),
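The TF auto classes dispatch on the configuration type through these ordered mappings, so the new LXMERT entries are what make the model reachable through TFAutoModel and TFAutoModelForPreTraining. A minimal sketch of that dispatch (not part of the diff; assumes TensorFlow and this version of transformers are installed):

from transformers import LxmertConfig, TFAutoModel

config = LxmertConfig()                  # default, randomly initialized LXMERT configuration
model = TFAutoModel.from_config(config)  # resolved through TF_MODEL_MAPPING
print(type(model).__name__)              # expected: TFLxmertModel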
utils/check_repo.py: 69 additions & 0 deletions
@@ -70,6 +70,34 @@
"marian": "marian.rst",
}

# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
"DPRContextEncoder",
"DPREncoder",
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel",
"ProphetNetDecoder",
"ProphetNetEncoder",
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFBertForNextSentencePrediction",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFMobileBertForNextSentencePrediction",
"TFOpenAIGPTDoubleHeadsModel",
"XLMForQuestionAnswering",
"XLMProphetNetDecoder",
"XLMProphetNetEncoder",
"XLNetForQuestionAnswering",
]

# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
@@ -282,6 +310,45 @@ def check_all_models_are_documented():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def get_all_auto_configured_models():
""" Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.modeling_auto, attr_name).values())
for attr_name in dir(transformers.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.modeling_tf_auto, attr_name).values())
return [cls.__name__ for cls in result]


def check_models_are_auto_configured(module, all_auto_models):
""" Check models defined in module are each in an auto class."""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
if model_name not in all_auto_models and model_name not in IGNORE_NON_AUTO_CONFIGURED:
failures.append(
f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
"If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
"`utils/check_repo.py`."
)
return failures


def check_all_models_are_auto_configured():
""" Check all models are each in an auto class."""
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
failures = []
for module in modules:
new_failures = check_models_are_auto_configured(module, all_auto_models)
if new_failures is not None:
failures += new_failures
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


_re_decorator = re.compile(r"^\s*@(\S+)\s+$")


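The name-based discovery in get_all_auto_configured_models above works because every auto mapping follows the MODEL_XXX_MAPPING / TF_MODEL_XXX_MAPPING naming scheme, so unioning the values of those mappings yields every class reachable through an auto class. A minimal sketch of the same idea (not part of the diff; assumes a repo checkout with TensorFlow installed):

import transformers

tf_auto = transformers.modeling_tf_auto
mapping_names = [
    name for name in dir(tf_auto) if name.startswith("TF_MODEL_") and name.endswith("MAPPING")
]
auto_classes = set()
for name in mapping_names:
    auto_classes |= set(getattr(tf_auto, name).values())  # union avoids duplicates across mappings
print("TFLxmertModel" in {cls.__name__ for cls in auto_classes})  # True once LXMERT is in the TF mappings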
@@ -325,6 +392,8 @@ def check_repo_quality():
check_all_models_are_tested()
print("Checking all models are properly documented.")
check_all_models_are_documented()
print("Checking all models are in at least one auto class.")
check_all_models_are_auto_configured()


if __name__ == "__main__":
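With check_all_models_are_auto_configured() wired into check_repo_quality(), the new check runs alongside the other repository consistency checks; presumably (the body of the if __name__ == "__main__": block is elided above) it is triggered by running the script directly:

python utils/check_repo.py

Any model class that is missing from every auto mapping and not listed in IGNORE_NON_AUTO_CONFIGURED then shows up in the exception raised by check_all_models_are_auto_configured().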
