From 2e5973f68837e04396e4400f741b4e92d3f4e1a8 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 24 Aug 2022 12:23:52 +0200 Subject: [PATCH 1/2] Fix model serialization using `torch.save()` & `load()` --- .../adapters/wrappers/configuration.py | 20 ------------------- src/transformers/configuration_utils.py | 8 ++++++++ 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/transformers/adapters/wrappers/configuration.py b/src/transformers/adapters/wrappers/configuration.py index 47656a6675..23be3b84a1 100644 --- a/src/transformers/adapters/wrappers/configuration.py +++ b/src/transformers/adapters/wrappers/configuration.py @@ -40,21 +40,6 @@ } -def _to_dict_new(self): - output = self._to_dict_original() - if hasattr(self, "adapters") and not isinstance(output["adapters"], dict): - output["adapters"] = self.adapters.to_dict() - if "custom_heads" in output.keys(): - del output["custom_heads"] - - # delete handles to overriden methods - del output["to_dict"] - del output["_to_dict_original"] - del output["is_adaptable"] - - return output - - def wrap_config(config: PretrainedConfig) -> PretrainedConfig: """ Makes required changes to a model config class to allow usage with adapters. @@ -86,11 +71,6 @@ def wrap_config(config: PretrainedConfig) -> PretrainedConfig: if key not in config.attribute_map: config.attribute_map[key] = value - # Override to_dict() to add adapters - if not hasattr(config, "_to_dict_original"): - config._to_dict_original = config.to_dict - config.to_dict = types.MethodType(_to_dict_new, config) - # Ensure custom_heads attribute is present if not hasattr(config, "custom_heads"): config.custom_heads = {} diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f66b5734bd..390eba758b 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -789,6 +789,14 @@ def to_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(output) + # Adapter-specific changes + if hasattr(self, "adapters") and not isinstance(output["adapters"], dict): + output["adapters"] = self.adapters.to_dict() + if "custom_heads" in output: + del output["custom_heads"] + if "is_adaptable" in output: + del output["is_adaptable"] + return output def to_json_string(self, use_diff: bool = True) -> str: From 3186753e7a0ab838900a3f5d5d574de7f637c640 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 24 Aug 2022 14:40:50 +0200 Subject: [PATCH 2/2] test fix --- .../adapters/wrappers/configuration.py | 2 -- src/transformers/configuration_utils.py | 18 +++++++++++------- .../configuration_encoder_decoder.py | 3 +++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/transformers/adapters/wrappers/configuration.py b/src/transformers/adapters/wrappers/configuration.py index 23be3b84a1..3ffb3f22de 100644 --- a/src/transformers/adapters/wrappers/configuration.py +++ b/src/transformers/adapters/wrappers/configuration.py @@ -1,5 +1,3 @@ -import types - from ...configuration_utils import PretrainedConfig from ...models.encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig from ..configuration import ModelAdaptersConfig diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 390eba758b..2073363cc2 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -771,6 +771,16 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict + def adapters_to_dict(self, output): + # Adapter-specific changes + if hasattr(self, "adapters") and not isinstance(output["adapters"], dict): + output["adapters"] = self.adapters.to_dict() + if "custom_heads" in output: + del output["custom_heads"] + if "is_adaptable" in output: + del output["is_adaptable"] + return output + def to_dict(self) -> Dict[str, Any]: """ Serializes this instance to a Python dictionary. @@ -789,13 +799,7 @@ def to_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(output) - # Adapter-specific changes - if hasattr(self, "adapters") and not isinstance(output["adapters"], dict): - output["adapters"] = self.adapters.to_dict() - if "custom_heads" in output: - del output["custom_heads"] - if "is_adaptable" in output: - del output["is_adaptable"] + self.adapters_to_dict(output) return output diff --git a/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py index 1fca8a10f7..5d562a61de 100644 --- a/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py @@ -116,4 +116,7 @@ def to_dict(self): output["encoder"] = self.encoder.to_dict() output["decoder"] = self.decoder.to_dict() output["model_type"] = self.__class__.model_type + + self.adapters_to_dict(output) + return output