Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check all models are in an auto class #8425

Merged
merged 1 commit into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FunnelConfig,
GPT2Config,
LongformerConfig,
LxmertConfig,
MobileBertConfig,
OpenAIGPTConfig,
RobertaConfig,
Expand Down Expand Up @@ -113,6 +114,7 @@
)
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
from .modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from .modeling_tf_marian import TFMarianMTModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration
from .modeling_tf_mobilebert import (
Expand Down Expand Up @@ -168,6 +170,7 @@

TF_MODEL_MAPPING = OrderedDict(
[
(LxmertConfig, TFLxmertModel),
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
Expand All @@ -192,6 +195,7 @@

TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(LxmertConfig, TFLxmertForPreTraining),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForPreTraining),
Expand Down
69 changes: 69 additions & 0 deletions utils/check_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,34 @@
"marian": "marian.rst",
}

# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
"DPRContextEncoder",
"DPREncoder",
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel",
"ProphetNetDecoder",
"ProphetNetEncoder",
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFBertForNextSentencePrediction",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFMobileBertForNextSentencePrediction",
"TFOpenAIGPTDoubleHeadsModel",
"XLMForQuestionAnswering",
"XLMProphetNetDecoder",
"XLMProphetNetEncoder",
"XLNetForQuestionAnswering",
]

# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
Expand Down Expand Up @@ -282,6 +310,45 @@ def check_all_models_are_documented():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def get_all_auto_configured_models():
""" Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.modeling_auto, attr_name).values())
for attr_name in dir(transformers.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.modeling_tf_auto, attr_name).values())
return [cls.__name__ for cls in result]


def check_models_are_auto_configured(module, all_auto_models):
""" Check models defined in module are each in an auto class."""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
if model_name not in all_auto_models and model_name not in IGNORE_NON_AUTO_CONFIGURED:
failures.append(
f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
"If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
"`utils/check_repo.py`."
)
return failures


def check_all_models_are_auto_configured():
""" Check all models are each in an auto class."""
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
failures = []
for module in modules:
new_failures = check_models_are_auto_configured(module, all_auto_models)
if new_failures is not None:
failures += new_failures
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


_re_decorator = re.compile(r"^\s*@(\S+)\s+$")


Expand Down Expand Up @@ -325,6 +392,8 @@ def check_repo_quality():
check_all_models_are_tested()
print("Checking all models are properly documented.")
check_all_models_are_documented()
print("Checking all models are in at least one auto class.")
check_all_models_are_auto_configured()


if __name__ == "__main__":
Expand Down