diff --git a/src/transformers/adapters/wrappers/configuration.py b/src/transformers/adapters/wrappers/configuration.py
index 47656a667..3ffb3f22d 100644
--- a/src/transformers/adapters/wrappers/configuration.py
+++ b/src/transformers/adapters/wrappers/configuration.py
@@ -1,5 +1,3 @@
-import types
-
 from ...configuration_utils import PretrainedConfig
 from ...models.encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
 from ..configuration import ModelAdaptersConfig
@@ -40,21 +38,6 @@
 }
 
 
-def _to_dict_new(self):
-    output = self._to_dict_original()
-    if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
-        output["adapters"] = self.adapters.to_dict()
-    if "custom_heads" in output.keys():
-        del output["custom_heads"]
-
-    # delete handles to overriden methods
-    del output["to_dict"]
-    del output["_to_dict_original"]
-    del output["is_adaptable"]
-
-    return output
-
-
 def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
     """
     Makes required changes to a model config class to allow usage with adapters.
@@ -86,11 +69,6 @@ def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
         if key not in config.attribute_map:
             config.attribute_map[key] = value
 
-    # Override to_dict() to add adapters
-    if not hasattr(config, "_to_dict_original"):
-        config._to_dict_original = config.to_dict
-        config.to_dict = types.MethodType(_to_dict_new, config)
-
     # Ensure custom_heads attribute is present
     if not hasattr(config, "custom_heads"):
         config.custom_heads = {}
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index f66b5734b..2073363cc 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -771,6 +771,16 @@ def to_diff_dict(self) -> Dict[str, Any]:
 
         return serializable_config_dict
 
+    def adapters_to_dict(self, output):
+        # Adapter-specific changes
+        if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
+            output["adapters"] = self.adapters.to_dict()
+        if "custom_heads" in output:
+            del output["custom_heads"]
+        if "is_adaptable" in output:
+            del output["is_adaptable"]
+        return output
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this instance to a Python dictionary.
@@ -789,6 +799,8 @@
 
         self.dict_torch_dtype_to_str(output)
 
+        self.adapters_to_dict(output)
+
         return output
 
     def to_json_string(self, use_diff: bool = True) -> str:
diff --git a/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
index 1fca8a10f..5d562a61d 100644
--- a/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -116,4 +116,7 @@ def to_dict(self):
         output["encoder"] = self.encoder.to_dict()
         output["decoder"] = self.decoder.to_dict()
         output["model_type"] = self.__class__.model_type
+
+        self.adapters_to_dict(output)
+
         return output
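The net effect of this refactor is easiest to see from the serialization side. The sketch below is a hypothetical usage example, not part of the diff: it assumes an adapter-transformers installation where `wrap_config` is importable from `transformers.adapters.wrappers.configuration`, and `BertConfig` is used purely for illustration.

```python
# Hypothetical usage sketch (not part of this diff). Illustrates how
# PretrainedConfig.to_dict() now produces adapter-aware output without
# the per-instance monkey-patching that wrap_config() used to perform.
from transformers import BertConfig
from transformers.adapters.wrappers.configuration import wrap_config

config = wrap_config(BertConfig())
output = config.to_dict()

# adapters_to_dict() serializes the ModelAdaptersConfig object ...
assert isinstance(output["adapters"], dict)
# ... and strips the bookkeeping attributes that wrap_config() attaches,
# so they never reach config.json.
assert "custom_heads" not in output
assert "is_adaptable" not in output
```

Moving the logic into a regular `adapters_to_dict` method also removes the need for the deleted `del output["to_dict"]` / `del output["_to_dict_original"]` cleanup: the old approach stored bound methods on the instance, which leaked into the serialized `__dict__` and had to be scrubbed back out.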