
Add support for multiple models for one config in auto classes #11150

Merged (3 commits), Apr 8, 2021
Changes from all commits
1 change: 1 addition & 0 deletions src/transformers/modeling_flax_utils.py
@@ -387,6 +387,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
         # get abs dir
         save_directory = os.path.abspath(save_directory)
         # save config as well
+        self.config.architectures = [self.__class__.__name__[4:]]
         self.config.save_pretrained(save_directory)
 
         # save model
1 change: 1 addition & 0 deletions src/transformers/modeling_tf_utils.py
@@ -1037,6 +1037,7 @@ def save_pretrained(self, save_directory, saved_model=False, version=1):
             logger.info(f"Saved model created in {saved_model_dir}")
 
         # Save configuration file
+        self.config.architectures = [self.__class__.__name__[2:]]
         self.config.save_pretrained(save_directory)
 
         # If we save using the predefined names, we can load using `from_pretrained`
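These two one-line additions record the saved model's architecture in its config under the framework-agnostic name, by slicing the framework prefix off the class name ("Flax" is 4 characters, "TF" is 2; the PyTorch classes carry no prefix, so their save_pretrained already stores the bare class name). A minimal sketch of the slicing, with class names used purely as illustration:

    # Slicing the class name drops the framework prefix, leaving the shared architecture name:
    assert "FlaxBertModel"[4:] == "BertModel"
    assert "TFFunnelBaseModel"[2:] == "FunnelBaseModel"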
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -22,6 +22,7 @@
 
 
 _import_structure = {
+    "auto_factory": ["get_values"],
     "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
     "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
     "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
@@ -104,6 +105,7 @@
 
 
 if TYPE_CHECKING:
+    from .auto_factory import get_values
     from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
     from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
     from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
39 changes: 35 additions & 4 deletions src/transformers/models/auto/auto_factory.py
@@ -328,6 +328,26 @@
 """
 
 
+def _get_model_class(config, model_mapping):
+    supported_models = model_mapping[type(config)]
+    if not isinstance(supported_models, (list, tuple)):
+        return supported_models
+
+    name_to_model = {model.__name__: model for model in supported_models}
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in name_to_model:
+            return name_to_model[arch]
+        elif f"TF{arch}" in name_to_model:
+            return name_to_model[f"TF{arch}"]
+        elif f"Flax{arch}" in name_to_model:
+            return name_to_model[f"Flax{arch}"]
+
+    # If no architecture is set in the config, or none of them matches the supported models, the first element of
+    # the tuple is the default.
+    return supported_models[0]
+
+
 class _BaseAutoModelClass:
     # Base class for auto models.
     _model_mapping = None
@@ -341,7 +361,8 @@ def __init__(self):
 
     def from_config(cls, config, **kwargs):
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)](config, **kwargs)
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class(config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -356,9 +377,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             )
 
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-cased"):
     from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
     new_class.from_pretrained = classmethod(from_pretrained)
     return new_class
+
+
+def get_values(model_mapping):
+    result = []
+    for model in model_mapping.values():
+        if isinstance(model, (list, tuple)):
+            result += list(model)
+        else:
+            result.append(model)
+
+    return result
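Taken together, `_get_model_class` resolves a tuple entry by matching `config.architectures` against the candidate class names (trying the `TF`/`Flax` prefixes for those frameworks), and `get_values` flattens a mapping whose values may be classes or tuples. A self-contained sketch of both behaviors, using stand-in classes rather than the real ones:

    from types import SimpleNamespace

    class FunnelModel: ...
    class FunnelBaseModel: ...

    # A mapping value as it may now appear in MODEL_MAPPING:
    supported_models = (FunnelModel, FunnelBaseModel)
    config = SimpleNamespace(architectures=["FunnelBaseModel"])

    # _get_model_class: a name -> class lookup driven by config.architectures
    name_to_model = {m.__name__: m for m in supported_models}
    assert name_to_model[config.architectures[0]] is FunnelBaseModel

    # get_values: flattening a mapping with mixed single/tuple values
    mapping = {"funnel": supported_models, "other": FunnelModel}
    flattened = [m for v in mapping.values() for m in (v if isinstance(v, tuple) else (v,))]
    assert flattened == [FunnelModel, FunnelBaseModel, FunnelModel]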
19 changes: 14 additions & 5 deletions src/transformers/models/auto/configuration_auto.py
@@ -247,29 +247,38 @@
 )
 
 
+def _get_class_name(model_class):
+    if isinstance(model_class, (list, tuple)):
+        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
+    return f":class:`~transformers.{model_class.__name__}`"
Comment on lines +250 to +253

Collaborator (author): This change and below are for the docstrings to properly render a model type with several models possible:
[screenshot of the rendered docstring]

Member: oh this is very nice
+
+
 def _list_model_options(indent, config_to_class=None, use_model_types=True):
     if config_to_class is None and not use_model_types:
         raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
     if use_model_types:
         if config_to_class is None:
-            model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
+            model_type_to_name = {
+                model_type: f":class:`~transformers.{config.__name__}`"
+                for model_type, config in CONFIG_MAPPING.items()
+            }
         else:
             model_type_to_name = {
-                model_type: config_to_class[config].__name__
+                model_type: _get_class_name(config_to_class[config])
                 for model_type, config in CONFIG_MAPPING.items()
                 if config in config_to_class
             }
         lines = [
-            f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
+            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
             for model_type in sorted(model_type_to_name.keys())
         ]
     else:
-        config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
+        config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
         config_to_model_name = {
             config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
         }
         lines = [
-            f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
+            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
             for config_name in sorted(config_to_name.keys())
         ]
     return "\n".join(lines)
3 changes: 2 additions & 1 deletion src/transformers/models/auto/modeling_auto.py
@@ -124,6 +124,7 @@
 )
 from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
 from ..funnel.modeling_funnel import (
+    FunnelBaseModel,
     FunnelForMaskedLM,
     FunnelForMultipleChoice,
     FunnelForPreTraining,
@@ -377,7 +378,7 @@
         (CTRLConfig, CTRLModel),
         (ElectraConfig, ElectraModel),
         (ReformerConfig, ReformerModel),
-        (FunnelConfig, FunnelModel),
+        (FunnelConfig, (FunnelModel, FunnelBaseModel)),
         (LxmertConfig, LxmertModel),
        (BertGenerationConfig, BertGenerationEncoder),
        (DebertaConfig, DebertaModel),
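With this tuple entry, `AutoModel` keeps `FunnelModel` as the default (first element) and only switches when the config names the other architecture. A sketch of the resulting behavior, mirroring the new test further down:

    from transformers import AutoModel, FunnelConfig

    config = FunnelConfig()                     # no architectures recorded
    model = AutoModel.from_config(config)       # -> FunnelModel, the default first element

    config.architectures = ["FunnelBaseModel"]  # as save_pretrained now records it
    model = AutoModel.from_config(config)       # -> FunnelBaseModel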
3 changes: 2 additions & 1 deletion src/transformers/models/auto/modeling_tf_auto.py
@@ -91,6 +91,7 @@
     TFFlaubertWithLMHeadModel,
 )
 from ..funnel.modeling_tf_funnel import (
+    TFFunnelBaseModel,
     TFFunnelForMaskedLM,
     TFFunnelForMultipleChoice,
     TFFunnelForPreTraining,
@@ -242,7 +243,7 @@
         (XLMConfig, TFXLMModel),
         (CTRLConfig, TFCTRLModel),
         (ElectraConfig, TFElectraModel),
-        (FunnelConfig, TFFunnelModel),
+        (FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
         (DPRConfig, TFDPRQuestionEncoder),
         (MPNetConfig, TFMPNetModel),
         (BartConfig, TFBartModel),
3 changes: 2 additions & 1 deletion tests/test_modeling_albert.py
@@ -17,6 +17,7 @@
 import unittest
 
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -234,7 +235,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )
32 changes: 28 additions & 4 deletions tests/test_modeling_auto.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import copy
+import tempfile
 import unittest
 
 from transformers import is_torch_available
@@ -46,6 +47,8 @@
         BertForSequenceClassification,
         BertForTokenClassification,
         BertModel,
+        FunnelBaseModel,
+        FunnelModel,
         GPT2Config,
         GPT2LMHeadModel,
         RobertaForMaskedLM,
@@ -218,6 +221,21 @@ def test_from_identifier_from_model_type(self):
         self.assertEqual(model.num_parameters(), 14410)
         self.assertEqual(model.num_parameters(only_trainable=True), 14410)
 
+    def test_from_pretrained_with_tuple_values(self):
+        # For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
+        model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
+        self.assertIsInstance(model, FunnelModel)
+
+        config = copy.deepcopy(model.config)
+        config.architectures = ["FunnelBaseModel"]
+        model = AutoModel.from_config(config)
+        self.assertIsInstance(model, FunnelBaseModel)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            model = AutoModel.from_pretrained(tmp_dir)
+            self.assertIsInstance(model, FunnelBaseModel)
+
     def test_parents_and_children_in_mappings(self):
         # Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
         # by the parents and will return the wrong configuration type when using auto models
@@ -242,6 +260,12 @@ def test_parents_and_children_in_mappings(self):
                     assert not issubclass(
                         child_config, parent_config
                     ), f"{child_config.__name__} is child of {parent_config.__name__}"
-                    assert not issubclass(
-                        child_model, parent_model
-                    ), f"{child_config.__name__} is child of {parent_config.__name__}"
+
+                    # Tuplify child_model and parent_model since some of them could be tuples.
+                    if not isinstance(child_model, (list, tuple)):
+                        child_model = (child_model,)
+                    if not isinstance(parent_model, (list, tuple)):
+                        parent_model = (parent_model,)
+
+                    for child, parent in [(a, b) for a in child_model for b in parent_model]:
+                        assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
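The tuplify step lets the parent/child check compare every combination uniformly. A minimal sketch of the cross-product check with stand-in classes:

    class Parent: ...
    class Other: ...
    class Child(Parent): ...

    # One side is a bare class, the other a tuple, as can now happen in the mappings:
    child_model, parent_model = Child, (Parent, Other)
    if not isinstance(child_model, (list, tuple)):
        child_model = (child_model,)

    # The cross-product visits (Child, Parent) and (Child, Other); the first pair
    # is the kind of ordering bug the assertion is meant to catch:
    violations = [(c, p) for c in child_model for p in parent_model if issubclass(c, p)]
    assert violations == [(Child, Parent)]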
3 changes: 2 additions & 1 deletion tests/test_modeling_bert.py
@@ -17,6 +17,7 @@
 import unittest
 
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -444,7 +445,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )
3 changes: 2 additions & 1 deletion tests/test_modeling_big_bird.py
@@ -19,6 +19,7 @@
 
 from tests.test_modeling_common import floats_tensor
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
 from transformers.testing_utils import require_torch, slow, torch_device
 
@@ -458,7 +459,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )
27 changes: 14 additions & 13 deletions tests/test_modeling_common.py
@@ -24,6 +24,7 @@
 
 from transformers import is_torch_available
 from transformers.file_utils import WEIGHTS_NAME
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
 
 
@@ -79,7 +80,7 @@ class ModelTesterMixin:
 
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = copy.deepcopy(inputs_dict)
-        if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+        if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
             inputs_dict = {
                 k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                 if isinstance(v, torch.Tensor) and v.ndim > 1
@@ -88,28 +89,28 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
             }
 
         if return_labels:
-            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
                 inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
-            elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+            elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                 inputs_dict["start_positions"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
                 inputs_dict["end_positions"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
             elif model_class in [
-                *MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
-                *MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
-                *MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(),
+                *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
+                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
+                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
             ]:
                 inputs_dict["labels"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
             elif model_class in [
-                *MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
-                *MODEL_FOR_CAUSAL_LM_MAPPING.values(),
-                *MODEL_FOR_MASKED_LM_MAPPING.values(),
-                *MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
+                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
+                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
+                *get_values(MODEL_FOR_MASKED_LM_MAPPING),
+                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
             ]:
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
@@ -229,7 +230,7 @@ def test_training(self):
         config.return_dict = True
 
         for model_class in self.all_model_classes:
-            if model_class in MODEL_MAPPING.values():
+            if model_class in get_values(MODEL_MAPPING):
                 continue
             model = model_class(config)
             model.to(torch_device)
@@ -248,7 +249,7 @@ def test_training_gradient_checkpointing(self):
         config.return_dict = True
 
        for model_class in self.all_model_classes:
-            if model_class in MODEL_MAPPING.values():
+            if model_class in get_values(MODEL_MAPPING):
                 continue
             model = model_class(config)
             model.to(torch_device)
@@ -312,7 +313,7 @@ def test_attention_outputs(self):
             if "labels" in inputs_dict:
                 correct_outlen += 1  # loss is added to beginning
             # Question Answering model returns start_logits and end_logits
-            if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                 correct_outlen += 1  # start_logits and end_logits instead of only 1 output
             if "past_key_values" in outputs:
                 correct_outlen += 1  # past_key_values have been returned
3 changes: 2 additions & 1 deletion tests/test_modeling_convbert.py
@@ -19,6 +19,7 @@
 
 from tests.test_modeling_common import floats_tensor
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -352,7 +353,7 @@ def test_attention_outputs(self):
             if "labels" in inputs_dict:
                 correct_outlen += 1  # loss is added to beginning
             # Question Answering model returns start_logits and end_logits
-            if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                 correct_outlen += 1  # start_logits and end_logits instead of only 1 output
             if "past_key_values" in outputs:
                 correct_outlen += 1  # past_key_values have been returned
3 changes: 2 additions & 1 deletion tests/test_modeling_electra.py
@@ -17,6 +17,7 @@
 import unittest
 
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -292,7 +293,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )