Fix model serialization using torch.save() & load() (#406)
calpt committed Aug 24, 2022
1 parent 5635d50 commit 4a06075
Showing 3 changed files with 15 additions and 22 deletions.
22 changes: 0 additions & 22 deletions src/transformers/adapters/wrappers/configuration.py
@@ -1,5 +1,3 @@
-import types
-
 from ...configuration_utils import PretrainedConfig
 from ...models.encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
 from ..configuration import ModelAdaptersConfig
@@ -40,21 +38,6 @@
 }
 
 
-def _to_dict_new(self):
-    output = self._to_dict_original()
-    if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
-        output["adapters"] = self.adapters.to_dict()
-    if "custom_heads" in output.keys():
-        del output["custom_heads"]
-
-    # delete handles to overriden methods
-    del output["to_dict"]
-    del output["_to_dict_original"]
-    del output["is_adaptable"]
-
-    return output
-
-
 def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
     """
     Makes required changes to a model config class to allow usage with adapters.
@@ -86,11 +69,6 @@ def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
         if key not in config.attribute_map:
             config.attribute_map[key] = value
 
-    # Override to_dict() to add adapters
-    if not hasattr(config, "_to_dict_original"):
-        config._to_dict_original = config.to_dict
-        config.to_dict = types.MethodType(_to_dict_new, config)
-
     # Ensure custom_heads attribute is present
     if not hasattr(config, "custom_heads"):
         config.custom_heads = {}
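The removed block in wrap_config() patched to_dict() onto each config instance with types.MethodType, leaving bound methods inside the instance's __dict__. Such objects do not survive a pickle round trip, which is what broke torch.save() and torch.load() on whole models. A minimal sketch of the failure mode, using a hypothetical DummyConfig stand-in rather than the real PretrainedConfig:

import pickle
import types

def _to_dict_new(self):
    # Stand-in for the patched serializer removed above
    return dict(vars(self))

class DummyConfig:
    """Hypothetical stand-in for a wrapped PretrainedConfig."""

cfg = DummyConfig()
# Pre-commit approach: bind the replacement onto the instance itself,
# so the bound method ends up inside cfg.__dict__.
cfg.to_dict = types.MethodType(_to_dict_new, cfg)

try:
    pickle.loads(pickle.dumps(cfg))
except Exception as err:
    # Unpickling resolves the method by name on the restored object,
    # which has no such attribute, so the round trip fails here.
    print(f"round trip failed: {type(err).__name__}: {err}")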
12 changes: 12 additions & 0 deletions src/transformers/configuration_utils.py
@@ -771,6 +771,16 @@ def to_diff_dict(self) -> Dict[str, Any]:
 
         return serializable_config_dict
 
+    def adapters_to_dict(self, output):
+        # Adapter-specific changes
+        if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
+            output["adapters"] = self.adapters.to_dict()
+        if "custom_heads" in output:
+            del output["custom_heads"]
+        if "is_adaptable" in output:
+            del output["is_adaptable"]
+        return output
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this instance to a Python dictionary.
@@ -789,6 +799,8 @@ def to_dict(self) -> Dict[str, Any]:
 
         self.dict_torch_dtype_to_str(output)
 
+        self.adapters_to_dict(output)
+
         return output
 
     def to_json_string(self, use_diff: bool = True) -> str:
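With the adapter handling moved into a regular method on PretrainedConfig, configs no longer carry instance-bound callables, so pickling whole models works again. A minimal sketch of the round trip this commit fixes, assuming adapter-transformers is installed (model and adapter names are illustrative):

import torch
from transformers import AutoAdapterModel  # provided by adapter-transformers

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("demo")

# Previously this failed while pickling the monkey-patched config.
torch.save(model, "model.pt")
restored = torch.load("model.pt")  # newer torch may need weights_only=False

# The adapter configuration survives the round trip.
print(restored.config.adapters.to_dict())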
3 changes: 3 additions & 0 deletions src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -116,4 +116,7 @@ def to_dict(self):
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type

self.adapters_to_dict(output)

return output
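EncoderDecoderConfig overrides to_dict() wholesale rather than calling the base implementation, so it needs its own call to adapters_to_dict(); for a plain, unwrapped config the call is a no-op. A small sketch of that code path (the configs used are illustrative):

from transformers import BertConfig, EncoderDecoderConfig

config = EncoderDecoderConfig.from_encoder_decoder_configs(
    BertConfig(), BertConfig()
)
d = config.to_dict()

# adapters_to_dict() strips adapter bookkeeping from wrapped configs;
# on this unwrapped config it passes the dictionary through unchanged.
assert "encoder" in d and "decoder" in d
assert "custom_heads" not in d and "is_adaptable" not in d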
