Skip to content

Commit

Permalink
fix: add serialize, deserialize
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Feb 22, 2024
1 parent fb7447c commit 2ad5db7
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
56 changes: 56 additions & 0 deletions deepmd/tf/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
self.seed = seed
self.uniform_seed = uniform_seed
self.seed_shift = one_layer_rand_seed_shift()
self.activation_function_name = activation_function
self.fitting_activation_fn = get_activation_func(activation_function)
self.fitting_precision = get_precision(precision)
self.dim_rot_mat_1 = embedding_width
Expand Down Expand Up @@ -333,3 +334,58 @@ def get_loss(self, loss: dict, lr) -> Loss:
tensor_size=3,
label_name="dipole",
)

def serialize(self, suffix: str) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
data = {
"var_name": "energy",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
# very bad design: type embedding is not passed to the class
# TODO: refactor the class
"distinguish_types": True,
"dim_out": 3,
"neuron": self.n_neuron,
"resnet_dt": self.resnet_dt,
"activation_function": self.activation_function_name,
"precision": self.fitting_precision.name,
"exclude_types": [],
"nets": self.serialize_network(
ntypes=self.ntypes,
# TODO: consider type embeddings
ndim=self.dim_rot_mat_1,
in_dim=self.dim_descrpt,
neuron=self.n_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.resnet_dt,
variables=self.fitting_net_variables,
suffix=suffix,
)
}
return data

@classmethod
def deserialize(cls, data: dict, suffix: str):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
fitting = cls(**data)
fitting.fitting_net_variables = cls.deserialize_network(
data["nets"],
suffix=suffix,
)
return fitting
1 change: 1 addition & 0 deletions source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
self.inputs.ravel(),
self.natoms,
self.atype,
None,
suffix,
)

Expand Down

0 comments on commit 2ad5db7

Please sign in to comment.