Skip to content

Commit

Permalink
fix(tf): fix model out_bias deserialize (#4350)
Browse files Browse the repository at this point in the history
per discussion

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced handling of model serialization and deserialization,
particularly for bias parameters.
- Updated output structure for the `PT` backend in the energy model
tests.

- **Bug Fixes**
- Improved logic for managing unsupported model configurations, ensuring
clearer error reporting.

- **Documentation**
- Updated method signatures to reflect changes in functionality for
model handling and testing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Nov 14, 2024
1 parent 6e815a2 commit d7cf48c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
46 changes: 43 additions & 3 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from deepmd.common import (
j_get_type,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.tf.descriptor.descriptor import (
Descriptor,
)
Expand Down Expand Up @@ -803,10 +806,34 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
-------
Descriptor
The deserialized descriptor
Raises
------
ValueError
If both fitting/@variables/bias_atom_e and @variables/out_bias are non-zero
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 2), 2, 1)
descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
if data["fitting"].get("@variables", {}).get("bias_atom_e") is not None:
# careful: copy each level and don't modify the input array,
# otherwise it will affect the original data
# deepcopy is not used for performance reasons
data["fitting"] = data["fitting"].copy()
data["fitting"]["@variables"] = data["fitting"]["@variables"].copy()
if (
int(np.any(data["fitting"]["@variables"]["bias_atom_e"]))
+ int(np.any(data["@variables"]["out_bias"]))
> 1
):
raise ValueError(
"fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero"
)
data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][
"@variables"
]["bias_atom_e"] + data["@variables"]["out_bias"].reshape(
data["fitting"]["@variables"]["bias_atom_e"].shape
)
fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
# pass descriptor type embedding to model
if descriptor.explicit_ntypes:
Expand All @@ -815,8 +842,10 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
else:
type_embedding = None
# BEGINE not supported keys
data.pop("atom_exclude_types")
data.pop("pair_exclude_types")
if len(data.pop("atom_exclude_types")) > 0:
raise NotImplementedError("atom_exclude_types is not supported")
if len(data.pop("pair_exclude_types")) > 0:
raise NotImplementedError("pair_exclude_types is not supported")
data.pop("rcond", None)
data.pop("preset_out_bias", None)
data.pop("@variables", None)
Expand Down Expand Up @@ -853,6 +882,17 @@ def serialize(self, suffix: str = "") -> dict:

ntypes = len(self.get_type_map())
dict_fit = self.fitting.serialize(suffix=suffix)
if dict_fit.get("@variables", {}).get("bias_atom_e") is not None:
out_bias = dict_fit["@variables"]["bias_atom_e"].reshape(
[1, ntypes, dict_fit["dim_out"]]
)
dict_fit["@variables"]["bias_atom_e"] = np.zeros_like(
dict_fit["@variables"]["bias_atom_e"]
)
else:
out_bias = np.zeros(
[1, ntypes, dict_fit["dim_out"]], dtype=GLOBAL_NP_FLOAT_PRECISION
)
return {
"@class": "Model",
"type": "standard",
Expand All @@ -866,7 +906,7 @@ def serialize(self, suffix: str = "") -> dict:
"rcond": None,
"preset_out_bias": None,
"@variables": {
"out_bias": np.zeros([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype
"out_bias": out_bias,
"out_std": np.ones([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype
},
}
Expand Down
4 changes: 3 additions & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def pass_data_to_cls(self, cls, data) -> Any:
if cls is EnergyModelDP:
return get_model_dp(data)
elif cls is EnergyModelPT:
return get_model_pt(data)
model = get_model_pt(data)
model.atomic_model.out_bias.uniform_()
return model
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.additional_data)
Expand Down

0 comments on commit d7cf48c

Please sign in to comment.