From 2ad5db7bf76d4b27f1091119f46460f0832468e9 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 22 Feb 2024 12:17:55 +0800 Subject: [PATCH] fix: add serialize, deserialize --- deepmd/tf/fit/dipole.py | 56 +++++++++++++++++++ .../tests/consistent/fitting/test_dipole.py | 1 + 2 files changed, 57 insertions(+) diff --git a/deepmd/tf/fit/dipole.py b/deepmd/tf/fit/dipole.py index b3916aa200..67dc31d14d 100644 --- a/deepmd/tf/fit/dipole.py +++ b/deepmd/tf/fit/dipole.py @@ -89,6 +89,7 @@ def __init__( self.seed = seed self.uniform_seed = uniform_seed self.seed_shift = one_layer_rand_seed_shift() + self.activation_function_name = activation_function self.fitting_activation_fn = get_activation_func(activation_function) self.fitting_precision = get_precision(precision) self.dim_rot_mat_1 = embedding_width @@ -333,3 +334,58 @@ def get_loss(self, loss: dict, lr) -> Loss: tensor_size=3, label_name="dipole", ) + + def serialize(self, suffix: str) -> dict: + """Serialize the model. + Returns + ------- + dict + The serialized data + """ + data = { + "var_name": "energy", + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + # very bad design: type embedding is not passed to the class + # TODO: refactor the class + "distinguish_types": True, + "dim_out": 3, + "neuron": self.n_neuron, + "resnet_dt": self.resnet_dt, + "activation_function": self.activation_function_name, + "precision": self.fitting_precision.name, + "exclude_types": [], + "nets": self.serialize_network( + ntypes=self.ntypes, + # TODO: consider type embeddings + ndim=self.dim_rot_mat_1, + in_dim=self.dim_descrpt, + neuron=self.n_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.resnet_dt, + variables=self.fitting_net_variables, + suffix=suffix, + ) + } + return data + + @classmethod + def deserialize(cls, data: dict, suffix: str): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + fitting = cls(**data) + fitting.fitting_net_variables = cls.deserialize_network( + data["nets"], + suffix=suffix, + ) + return fitting diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index e0a07e0c52..966127245f 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -119,6 +119,7 @@ def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: self.inputs.ravel(), self.natoms, self.atype, + None, suffix, )