
Commit

from_pretrained: check that the pretrained model is for the right model architecture (#10586)

* Added a check to ensure the pretrained model passed to from_pretrained and the model class refer to the same architecture

* Added a test checking that from_pretrained raises an assertion error when passed an incompatible model name

* Rewrote the assert in from_pretrained with f-strings. Updated the test to ensure the expected assertion message is generated

* Added a check to ensure both the config and the model have a model_type

* Fix FlauBERT heads

Co-authored-by: vimarsh chaturvedi <vimarsh chaturvedi>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
3 people authored Mar 18, 2021
1 parent 4f3e93c commit 094afa5
Showing 3 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
@@ -384,6 +384,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
"""
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type", False) and hasattr(cls, "model_type"):
assert (
config_dict["model_type"] == cls.model_type
), f"You tried to initiate a model of type '{cls.model_type}' with a pretrained model of type '{config_dict['model_type']}'"

return cls.from_dict(config_dict, **kwargs)

@classmethod
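For context, a minimal sketch of how the new guard surfaces to users, assuming the tiny T5 checkpoint referenced in the test below is still reachable on the Hub (the checkpoint name comes from this commit's test; everything else here is illustrative):

from transformers import BertConfig, T5Config

# Loading a T5 checkpoint through its matching config class works as before.
t5_config = T5Config.from_pretrained("patrickvonplaten/t5-tiny-random")

# Loading the same checkpoint through BertConfig now trips the new assert,
# since the checkpoint's config carries "model_type": "t5" while BertConfig.model_type is "bert".
try:
    BertConfig.from_pretrained("patrickvonplaten/t5-tiny-random")
except AssertionError as err:
    print(err)
    # You tried to initiate a model of type 'bert' with a pretrained model of type 't5'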
4 changes: 4 additions & 0 deletions src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -932,6 +932,8 @@ def __init__(self, config, *inputs, **kwargs):
FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
config_class = FlaubertConfig

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
@@ -945,6 +947,8 @@ def __init__(self, config, *inputs, **kwargs):
FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
config_class = FlaubertConfig

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
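The FlauBERT fix is needed because these two heads subclass the corresponding XLM heads and so inherited config_class from XLM; with the new model_type guard, loading a FlauBERT checkpoint (model_type "flaubert") through an XLM config class would fail the comparison. A minimal sketch of the invariant the override restores, assuming TensorFlow is installed so the TF classes import (model_type values are taken from the library's FlauBERT/XLM configs):

from transformers import FlaubertConfig, XLMConfig
from transformers.models.flaubert.modeling_tf_flaubert import TFFlaubertForTokenClassification

# After this commit the FlauBERT head advertises the FlauBERT config class, so the
# model_type check in PretrainedConfig.from_pretrained compares against "flaubert"
# rather than the inherited "xlm".
assert TFFlaubertForTokenClassification.config_class is FlaubertConfig
assert FlaubertConfig.model_type == "flaubert"
assert XLMConfig.model_type == "xlm"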
12 changes: 12 additions & 0 deletions tests/test_modeling_common.py
@@ -47,6 +47,7 @@
BertModel,
PretrainedConfig,
PreTrainedModel,
T5ForConditionalGeneration,
)


@@ -58,6 +59,9 @@ def _config_zero_init(config):
return configs_no_init


TINY_T5 = "patrickvonplaten/t5-tiny-random"


@require_torch
class ModelTesterMixin:

@@ -1284,3 +1288,11 @@ def test_model_from_pretrained(self):
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)

def test_model_from_pretrained_with_different_pretrained_model_name(self):
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertIsNotNone(model)

with self.assertRaises(Exception) as context:
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You tried to initiate a model of type" in str(context.exception))
