Fix model serialization using torch.save() & load() #406

Merged · 2 commits · Aug 24, 2022
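Background: saving a whole model with torch.save() pickles its config, and the adapter wrapper used to override config.to_dict per instance via types.MethodType. Bound methods stored in an instance __dict__ do not survive a pickle round-trip. Below is a minimal sketch of the failure mode, with an illustrative Config class standing in for PretrainedConfig (not the actual transformers classes):

```python
import pickle
import types

class Config:
    def to_dict(self):
        return dict(self.__dict__)

def _to_dict_new(self):
    return self._to_dict_original()

cfg = Config()
# The pre-PR wrapper stored bound methods as instance attributes:
cfg._to_dict_original = cfg.to_dict
cfg.to_dict = types.MethodType(_to_dict_new, cfg)

blob = pickle.dumps(cfg)  # dumping the object may succeed...
pickle.loads(blob)        # ...but the round-trip typically breaks: CPython's pickle
                          # rebuilds a bound method as getattr(obj, "_to_dict_new"),
                          # which does not resolve on the restored instance
```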
22 changes: 0 additions & 22 deletions src/transformers/adapters/wrappers/configuration.py
@@ -1,5 +1,3 @@
-import types
-
 from ...configuration_utils import PretrainedConfig
 from ...models.encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
 from ..configuration import ModelAdaptersConfig
@@ -40,21 +38,6 @@
 }
 
 
-def _to_dict_new(self):
-    output = self._to_dict_original()
-    if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
-        output["adapters"] = self.adapters.to_dict()
-    if "custom_heads" in output.keys():
-        del output["custom_heads"]
-
-    # delete handles to overridden methods
-    del output["to_dict"]
-    del output["_to_dict_original"]
-    del output["is_adaptable"]
-
-    return output
-
-
 def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
     """
     Makes required changes to a model config class to allow usage with adapters.
@@ -86,11 +69,6 @@ def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
         if key not in config.attribute_map:
             config.attribute_map[key] = value
 
-    # Override to_dict() to add adapters
-    if not hasattr(config, "_to_dict_original"):
-        config._to_dict_original = config.to_dict
-        config.to_dict = types.MethodType(_to_dict_new, config)
-
     # Ensure custom_heads attribute is present
     if not hasattr(config, "custom_heads"):
         config.custom_heads = {}
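The hand-written del statements in the removed _to_dict_new hint at a second problem: PretrainedConfig.to_dict() starts from the instance __dict__, so every monkey-patched handle leaked into the serialized output. A short sketch, reusing the same illustrative Config from above:

```python
import types

class Config:
    def to_dict(self):
        return dict(self.__dict__)

def _to_dict_new(self):
    return self._to_dict_original()

cfg = Config()
cfg._to_dict_original = cfg.to_dict
cfg.to_dict = types.MethodType(_to_dict_new, cfg)

print(sorted(cfg.to_dict()))
# ['_to_dict_original', 'to_dict']  <- method handles end up in the "serialized"
#                                      dict, which is why the deleted code had to
#                                      remove them by hand
```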
12 changes: 12 additions & 0 deletions src/transformers/configuration_utils.py
@@ -771,6 +771,16 @@ def to_diff_dict(self) -> Dict[str, Any]:
 
         return serializable_config_dict
 
+    def adapters_to_dict(self, output):
+        # Adapter-specific changes
+        if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
+            output["adapters"] = self.adapters.to_dict()
+        if "custom_heads" in output:
+            del output["custom_heads"]
+        if "is_adaptable" in output:
+            del output["is_adaptable"]
+        return output
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this instance to a Python dictionary.
@@ -789,6 +799,8 @@ def to_dict(self) -> Dict[str, Any]:
 
         self.dict_torch_dtype_to_str(output)
 
+        self.adapters_to_dict(output)
+
         return output
 
     def to_json_string(self, use_diff: bool = True) -> str:
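With adapters_to_dict() defined as an ordinary method on PretrainedConfig, no per-instance handles are needed and the config pickles like any other object. A hedged round-trip check, where the model class is the adapter-transformers entry point and the checkpoint and adapter names are illustrative:

```python
import torch
from transformers import AutoAdapterModel  # adapter-transformers fork of transformers

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("demo")

torch.save(model, "model.pt")      # pickles the model together with its config
restored = torch.load("model.pt")  # raised a pickling error before this PR
print(type(restored.config))
```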
3 changes: 3 additions & 0 deletions src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -116,4 +116,7 @@ def to_dict(self):
         output["encoder"] = self.encoder.to_dict()
         output["decoder"] = self.decoder.to_dict()
         output["model_type"] = self.__class__.model_type
+
+        self.adapters_to_dict(output)
+
         return output
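EncoderDecoderConfig.to_dict() rebuilds its dictionary from scratch instead of delegating to PretrainedConfig.to_dict(), so it never passes through the new hook and has to call it explicitly. A quick, hedged way to see the composite config serialize cleanly (the sub-configs are illustrative, and the adapters entry only appears once a model has wrapped the config):

```python
import json
from transformers import BertConfig, EncoderDecoderConfig

enc = BertConfig()
dec = BertConfig()
cfg = EncoderDecoderConfig.from_encoder_decoder_configs(enc, dec)

output = cfg.to_dict()       # routes through adapters_to_dict() as of this PR
json.dumps(output)           # must not raise: only plain data is left
print(output["model_type"])  # encoder-decoder
```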