From 06c7238c470be955e8dc342ab34fe3a761383e51 Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 10 Nov 2021 18:15:54 +0000 Subject: [PATCH 1/3] Experimenting with adding proper get_config() and from_config() methods --- src/transformers/modeling_tf_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index d46b905c4f7..d7fcb6e56ec 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -692,6 +692,13 @@ def __init__(self, config, *inputs, **kwargs): self.config = config self.name_or_path = config.name_or_path + def get_config(self): + return self.config + + @classmethod + def from_config(cls, config, **kwargs): + return cls._from_config(config, **kwargs) + @classmethod def _from_config(cls, config, **kwargs): """ From 07b522a07a3c74d14f98b157678b66615a79aa5f Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 10 Nov 2021 19:15:12 +0000 Subject: [PATCH 2/3] Adding a test for get/from config --- tests/test_modeling_tf_common.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 2448ae9c90a..6faf5268f02 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -163,6 +163,19 @@ def test_save_load(self): self.assert_outputs_same(after_outputs, outputs) + def test_save_load_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + outputs = model(self._prepare_for_class(inputs_dict, model_class)) + + new_model = model_class.from_config(model.get_config()) + new_model.set_weights(model.get_weights()) + after_outputs = model(self._prepare_for_class(inputs_dict, model_class)) + + self.assert_outputs_same(after_outputs, outputs) + @tooslow def test_graph_mode(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 5a424c137a26056c6e6891eedc7539605f580be7 Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 10 Nov 2021 19:47:44 +0000 Subject: [PATCH 3/3] Fix test for get/from config --- tests/test_modeling_tf_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 6faf5268f02..30a7052daff 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -171,8 +171,9 @@ def test_save_load_config(self): outputs = model(self._prepare_for_class(inputs_dict, model_class)) new_model = model_class.from_config(model.get_config()) + _ = new_model(self._prepare_for_class(inputs_dict, model_class)) # Build model new_model.set_weights(model.get_weights()) - after_outputs = model(self._prepare_for_class(inputs_dict, model_class)) + after_outputs = new_model(self._prepare_for_class(inputs_dict, model_class)) self.assert_outputs_same(after_outputs, outputs)