Add support for multiple models for one config in auto classes #11150

Merged · 3 commits · Apr 8, 2021
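
In short: an auto-class mapping entry can now be a tuple of model classes instead of a single class, and the `architectures` field saved in the config decides which member gets instantiated. A minimal sketch of the resulting behavior (PyTorch side, using the Funnel classes and the tiny checkpoint from the tests in this diff):

from transformers import AutoModel, FunnelBaseModel, FunnelModel

# FunnelConfig now maps to the tuple (FunnelModel, FunnelBaseModel).
model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
assert isinstance(model, FunnelModel)  # no architecture match -> first tuple entry

# The non-default class is selected through config.architectures.
config = model.config
config.architectures = ["FunnelBaseModel"]
assert isinstance(AutoModel.from_config(config), FunnelBaseModel)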
1 change: 1 addition & 0 deletions src/transformers/modeling_flax_utils.py
@@ -387,6 +387,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
# get abs dir
save_directory = os.path.abspath(save_directory)
# save config as well
self.config.architectures = [self.__class__.__name__[4:]]
self.config.save_pretrained(save_directory)

# save model
1 change: 1 addition & 0 deletions src/transformers/modeling_tf_utils.py
@@ -1037,6 +1037,7 @@ def save_pretrained(self, save_directory, saved_model=False, version=1):
logger.info(f"Saved model created in {saved_model_dir}")

# Save configuration file
self.config.architectures = [self.__class__.__name__[2:]]
self.config.save_pretrained(save_directory)

# If we save using the predefined names, we can load using `from_pretrained`
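Note the slicing in the two hunks above: `self.__class__.__name__[4:]` drops the "Flax" prefix and `[2:]` drops the "TF" prefix, so both frameworks record the framework-neutral (PyTorch-style) class name in `config.architectures`. For a concrete name:

# What the slices produce for a concrete class name:
"FlaxBertModel"[4:]  # -> "BertModel"
"TFBertModel"[2:]    # -> "BertModel"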
28 changes: 24 additions & 4 deletions src/transformers/models/auto/auto_factory.py
@@ -328,6 +328,26 @@
"""


def _get_model_class(config, model_mapping):
supported_models = model_mapping[type(config)]
if not isinstance(supported_models, (list, tuple)):
return supported_models

name_to_model = {model.__name__: model for model in supported_models}
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in name_to_model:
return name_to_model[arch]
elif f"TF{arch}" in name_to_model:
return name_to_model[f"TF{arch}"]
elif f"Flax{arch}" in name_to_model:
return name_to_model[f"Flax{arch}"]

    # If no architecture is set in the config, or none of them match the supported models, the first element
    # of the tuple is the default.
return supported_models[0]


class _BaseAutoModelClass:
# Base class for auto models.
_model_mapping = None
@@ -341,7 +361,8 @@ def __init__(self):

def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)](config, **kwargs)
model_class = _get_model_class(config, cls._model_mapping)
return model_class(config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -356,9 +377,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
)

if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)].from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
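The resolution order in `_get_model_class`, sketched on the Funnel entries registered further down (each name in `config.architectures` is tried bare, then with a "TF" prefix, then with a "Flax" prefix):

# Sketch, assuming the mappings registered below in this PR:
# PyTorch mapping: (FunnelModel, FunnelBaseModel)
#   architectures=["FunnelBaseModel"] -> "FunnelBaseModel" matches directly.
# TF mapping: (TFFunnelModel, TFFunnelBaseModel)
#   architectures=["FunnelBaseModel"] -> "TFFunnelBaseModel" matches via the TF prefix.
# No architectures set, or no match -> supported_models[0], the declared default.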
10 changes: 8 additions & 2 deletions src/transformers/models/auto/configuration_auto.py
@@ -247,6 +247,12 @@
)


def _get_class_name(model_class):
if isinstance(model_class, (list, tuple)):
return " or ".join([c.__name__ for c in model_class])
return model_class.__name__


def _list_model_options(indent, config_to_class=None, use_model_types=True):
if config_to_class is None and not use_model_types:
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
@@ -255,7 +261,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
else:
model_type_to_name = {
model_type: config_to_class[config].__name__
model_type: _get_class_name(config_to_class[config])
for model_type, config in CONFIG_MAPPING.items()
if config in config_to_class
}
@@ -264,7 +270,7 @@
for model_type in sorted(model_type_to_name.keys())
]
else:
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
config_to_model_name = {
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
}
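For the docstring templates, `_get_class_name` renders tuple entries with "or" so the generated docs list every candidate class:

_get_class_name((FunnelModel, FunnelBaseModel))  # -> "FunnelModel or FunnelBaseModel"
_get_class_name(FunnelModel)                     # -> "FunnelModel"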
3 changes: 2 additions & 1 deletion src/transformers/models/auto/modeling_auto.py
@@ -124,6 +124,7 @@
)
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
from ..funnel.modeling_funnel import (
FunnelBaseModel,
FunnelForMaskedLM,
FunnelForMultipleChoice,
FunnelForPreTraining,
@@ -377,7 +378,7 @@
(CTRLConfig, CTRLModel),
(ElectraConfig, ElectraModel),
(ReformerConfig, ReformerModel),
(FunnelConfig, FunnelModel),
(FunnelConfig, (FunnelModel, FunnelBaseModel)),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
(DebertaConfig, DebertaModel),
3 changes: 2 additions & 1 deletion src/transformers/models/auto/modeling_tf_auto.py
@@ -91,6 +91,7 @@
TFFlaubertWithLMHeadModel,
)
from ..funnel.modeling_tf_funnel import (
TFFunnelBaseModel,
TFFunnelForMaskedLM,
TFFunnelForMultipleChoice,
TFFunnelForPreTraining,
@@ -242,7 +243,7 @@
(XLMConfig, TFXLMModel),
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, TFFunnelModel),
(FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
(DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),
32 changes: 28 additions & 4 deletions tests/test_modeling_auto.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import tempfile
import unittest

from transformers import is_torch_available
@@ -46,6 +47,8 @@
BertForSequenceClassification,
BertForTokenClassification,
BertModel,
FunnelBaseModel,
FunnelModel,
GPT2Config,
GPT2LMHeadModel,
RobertaForMaskedLM,
@@ -218,6 +221,21 @@ def test_from_identifier_from_model_type(self):
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)

def test_from_pretrained_with_tuple_values(self):
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
self.assertIsInstance(model, FunnelModel)

config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = AutoModel.from_config(config)
self.assertIsInstance(model, FunnelBaseModel)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = AutoModel.from_pretrained(tmp_dir)
self.assertIsInstance(model, FunnelBaseModel)

def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
# by the parents and will return the wrong configuration type when using auto models
@@ -242,6 +260,12 @@ def test_parents_and_children_in_mappings(self):
assert not issubclass(
child_config, parent_config
), f"{child_config.__name__} is child of {parent_config.__name__}"
assert not issubclass(
child_model, parent_model
), f"{child_config.__name__} is child of {parent_config.__name__}"

# Tuplify child_model and parent_model since some of them could be tuples.
if not isinstance(child_model, (list, tuple)):
child_model = (child_model,)
if not isinstance(parent_model, (list, tuple)):
parent_model = (parent_model,)

for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
2 changes: 2 additions & 0 deletions tests/test_modeling_flax_bert.py
@@ -29,6 +29,7 @@
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
[Review comment · Member] Nice catch!

[Reply · Collaborator, Author] Not mine, it's the script. I'll let it know you liked ;-)

FlaxBertForTokenClassification,
FlaxBertModel,
)
@@ -125,6 +126,7 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
FlaxBertForMultipleChoice,
FlaxBertForQuestionAnswering,
FlaxBertForNextSentencePrediction,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertForQuestionAnswering,
)
30 changes: 28 additions & 2 deletions tests/test_modeling_tf_auto.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import tempfile
import unittest

from transformers import is_tf_available
@@ -39,6 +40,8 @@
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFBertModel,
TFFunnelBaseModel,
TFFunnelModel,
TFGPT2LMHeadModel,
TFRobertaForMaskedLM,
TFT5ForConditionalGeneration,
@@ -176,6 +179,21 @@ def test_from_identifier_from_model_type(self):
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)

def test_from_pretrained_with_tuple_values(self):
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
self.assertIsInstance(model, TFFunnelModel)

config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = TFAutoModel.from_config(config)
self.assertIsInstance(model, TFFunnelBaseModel)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = TFAutoModel.from_pretrained(tmp_dir)
self.assertIsInstance(model, TFFunnelBaseModel)

def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
# by the parents and will return the wrong configuration type when using auto models
@@ -197,4 +215,12 @@ def test_parents_and_children_in_mappings(self):
for parent_config, parent_model in mapping[: index + 1]:
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
self.assertFalse(issubclass(child_config, parent_config))
self.assertFalse(issubclass(child_model, parent_model))

# Tuplify child_model and parent_model since some of them could be tuples.
if not isinstance(child_model, (list, tuple)):
child_model = (child_model,)
if not isinstance(parent_model, (list, tuple)):
parent_model = (parent_model,)

for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
22 changes: 17 additions & 5 deletions utils/check_repo.py
@@ -86,7 +86,6 @@
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel",
"RagModel",
@@ -95,7 +94,6 @@
"T5Stack",
"TFDPRReader",
"TFDPRSpanPredictor",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel",
"TFRagModel",
@@ -153,7 +151,7 @@ def get_model_modules():
def get_models(module):
""" Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel)
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
if "Pretrained" in attr_name or "PreTrained" in attr_name:
continue
@@ -244,15 +242,29 @@ def check_all_models_are_tested():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def _list_models(model_mapping):
result = []
for model in model_mapping.values():
if isinstance(model, (list, tuple)):
result += list(model)
else:
result.append(model)

return result


def get_all_auto_configured_models():
""" Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_auto, attr_name).values())
result = result | set(_list_models(getattr(transformers.models.auto.modeling_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_tf_auto, attr_name).values())
result = result | set(_list_models(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(_list_models(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
return [cls.__name__ for cls in result]


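
`_list_models` flattens tuple values so every member class gets checked by the repo consistency script. A small illustration (mapping contents hypothetical):

# Hypothetical mapping with one tuple value and one plain value:
mapping = {FunnelConfig: (FunnelModel, FunnelBaseModel), BertConfig: BertModel}
_list_models(mapping)  # -> [FunnelModel, FunnelBaseModel, BertModel]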