Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add DOS net #3452

Merged
merged 32 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
159c4c6
feat: add dos net
anyangml Mar 12, 2024
1f8c74c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
f9b0b06
feat: add dp
anyangml Mar 12, 2024
a990122
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
5ce0e71
fix: serialize
anyangml Mar 13, 2024
9e396e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
62f4150
fix: dim_out serialize
anyangml Mar 13, 2024
40ee0f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
812d563
fix: UTs
anyangml Mar 13, 2024
407eb48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
8ae55d3
fix: UTs
anyangml Mar 13, 2024
ac5f1be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
fc9d01a
feat: add UTs
anyangml Mar 13, 2024
81a16b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
ef008d5
Merge branch 'devel' into feat/dos
anyangml Mar 13, 2024
8a7c250
feat: add training
anyangml Mar 13, 2024
e212776
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
4b985b4
fix: hack consistency UT
anyangml Mar 14, 2024
3825274
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
a61f462
fix: remove UT hack
Mar 14, 2024
e5b6cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
06528f4
fix: precommit
Mar 14, 2024
6a1f995
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
d7ecbba
fix: UTs
anyangml Mar 14, 2024
3916f5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
9f3b47d
Merge branch 'devel' into feat/dos
anyangml Mar 14, 2024
71f24ff
fix: update tf UTs
anyangml Mar 14, 2024
ad1d5ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
73313f0
fix: deep test
anyangml Mar 15, 2024
5cc34df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
60315ee
Merge branch 'devel' into feat/dos
anyangml Mar 15, 2024
784d7b9
fix: address comments
anyangml Mar 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .dipole_fitting import (
DipoleFitting,
)
from .dos_fitting import (
DOSFittingNet,
)
from .ener_fitting import (
EnergyFittingNet,
)
Expand All @@ -21,4 +24,5 @@
"DipoleFitting",
"EnergyFittingNet",
"PolarFitting",
"DOSFittingNet",
]
93 changes: 93 additions & 0 deletions deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
TYPE_CHECKING,
List,
Optional,
Union,
)

import numpy as np

from deepmd.dpmodel.common import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
)

if TYPE_CHECKING:
from deepmd.dpmodel.fitting.general_fitting import (

Check warning on line 20 in deepmd/dpmodel/fitting/dos_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/dos_fitting.py#L20

Added line #L20 was not covered by tests
GeneralFitting,
)

from deepmd.utils.version import (
check_version_compatibility,
)


@InvarFitting.register("dos")
class DOSFittingNet(InvarFitting):
    """NumPy-backend fitting network for the electronic density of states (DOS).

    Predicts a per-atom DOS vector of length ``numb_dos``; the invariant
    fitting machinery in the parent class handles the reduction to a
    per-frame DOS.

    Parameters
    ----------
    ntypes : int
        Number of atom types.
    dim_descrpt : int
        Dimension of the input descriptor.
    numb_dos : int
        Number of grid points on which the DOS is evaluated
        (output dimension per atom).
    neuron : list[int], optional
        Sizes of the hidden layers of the fitting net.
        Defaults to ``[120, 120, 120]``.
    resnet_dt : bool
        Use a learnable time step ``dt`` in the ResNet construction.
    numb_fparam : int
        Number of frame parameters.
    numb_aparam : int
        Number of atomic parameters.
    bias_dos : np.ndarray, optional
        Per-type bias added to the predicted DOS, shape
        ``(ntypes, numb_dos)``. Defaults to zeros.
    rcond : float, optional
        Condition number used when regressing the atomic bias.
    trainable : bool or list[bool]
        Whether the network (or each of its layers) is trainable.
    activation_function : str
        Activation function of the fitting net.
    precision : str
        Numerical precision of the network parameters.
    mixed_types : bool
        Whether the fitting net treats atom types in a mixed fashion.
    exclude_types : list[int], optional
        Atom types excluded from the fitting. Defaults to ``[]``.
    seed : int, optional
        Random seed (currently unused).
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        numb_dos: int = 300,
        neuron: Optional[List[int]] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        bias_dos: Optional[np.ndarray] = None,
        rcond: Optional[float] = None,
        trainable: Union[bool, List[bool]] = True,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = False,
        exclude_types: Optional[List[int]] = None,
        # not used
        seed: Optional[int] = None,
    ):
        # Avoid shared mutable default arguments: resolve the defaults here.
        if neuron is None:
            neuron = [120, 120, 120]
        if exclude_types is None:
            exclude_types = []
        if bias_dos is not None:
            self.bias_dos = bias_dos
        else:
            self.bias_dos = np.zeros((ntypes, numb_dos), dtype=float)
        super().__init__(
            var_name="dos",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            dim_out=numb_dos,
            neuron=neuron,
            resnet_dt=resnet_dt,
            # Forward the resolved bias (not the raw argument) so that a
            # None input actually yields the zero bias computed above.
            bias_atom=self.bias_dos,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            rcond=rcond,
            trainable=trainable,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            exclude_types=exclude_types,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        """Construct a DOSFittingNet from a serialized dict.

        Translates the generic serialized keys back into ``__init__``
        arguments and drops keys that are fixed for the DOS fitting.
        """
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("var_name")
        data["numb_dos"] = data.pop("dim_out")
        # These are not accepted by __init__; their values are implied by
        # the DOS fitting type.
        data.pop("tot_ener_zero")
        data.pop("layer_name")
        data.pop("use_aparam_as_mask")
        data.pop("spin")
        data.pop("atom_ener")
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict."""
        dd = {
            **super().serialize(),
            "type": "dos",
        }
        dd["@variables"]["bias_atom_e"] = self.bias_atom_e
        return dd
9 changes: 8 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
The dimension of the input descriptor.
neuron
Number of neurons :math:`N` in each hidden layer of the fitting net
bias_atom_e
Average energy per atom for each element.
resnet_dt
Time-step `dt` in the resnet construction:
:math:`y = x + dt * \phi (Wx + b)`
Expand Down Expand Up @@ -85,6 +87,7 @@
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
bias_atom_e: Optional[np.ndarray] = None,
anyangml marked this conversation as resolved.
Show resolved Hide resolved
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[List[bool]] = None,
Expand Down Expand Up @@ -125,7 +128,11 @@

net_dim_out = self._net_out_dim()
# init constants
self.bias_atom_e = np.zeros([self.ntypes, net_dim_out])
if bias_atom_e is None:
self.bias_atom_e = np.zeros([self.ntypes, net_dim_out])

Check warning on line 132 in deepmd/dpmodel/fitting/general_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/general_fitting.py#L131-L132

Added lines #L131 - L132 were not covered by tests
else:
assert bias_atom_e.shape == (self.ntypes, net_dim_out)
self.bias_atom_e = bias_atom_e

Check warning on line 135 in deepmd/dpmodel/fitting/general_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/general_fitting.py#L134-L135

Added lines #L134 - L135 were not covered by tests
if self.numb_fparam > 0:
self.fparam_avg = np.zeros(self.numb_fparam)
self.fparam_inv_std = np.ones(self.numb_fparam)
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class InvarFitting(GeneralFitting):
Number of atomic parameter
rcond
The condition number for the regression of atomic energy.
bias_atom
Bias for each element.
tot_ener_zero
Force the total energy to zero. Useful for the charge fitting.
trainable
Expand Down Expand Up @@ -117,10 +119,11 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
bias_atom: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[List[bool]] = None,
atom_ener: Optional[List[float]] = [],
atom_ener: Optional[List[float]] = None,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
layer_name: Optional[List[Optional[str]]] = None,
Expand Down Expand Up @@ -152,6 +155,7 @@ def __init__(
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
rcond=rcond,
bias_atom_e=bias_atom,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
activation_function=activation_function,
Expand Down
80 changes: 80 additions & 0 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L2

Added line #L2 was not covered by tests
Dict,
Optional,
)

import torch

Check warning on line 7 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L7

Added line #L7 was not covered by tests

from .dp_model import (

Check warning on line 9 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L9

Added line #L9 was not covered by tests
DPModel,
)
Comment on lines +9 to +11

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.dp_model
begins an import cycle.


class DOSModel(DPModel):
    """PyTorch model predicting the electronic density of states (DOS).

    Thin wrapper over :class:`DPModel` that renames the generic fitting
    outputs (``dos`` / ``dos_redu``) to the DOS-specific prediction keys
    (``atom_dos`` / ``dos``).
    """

    model_type = "dos"

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        coord,
        atype,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Compute per-atom (``atom_dos``) and per-frame (``dos``) DOS
        from local coordinates and atom types.
        """
        model_ret = self.forward_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]

            if "mask" in model_ret:
                model_predict["mask"] = model_ret["mask"]
        else:
            model_predict = model_ret
            # NOTE(review): inherited from the energy model's no-fitting
            # branch; assumes "updated_coord" is present in model_ret —
            # confirm for DOS-style descriptor-only runs.
            model_predict["updated_coord"] += coord
        return model_predict

    @torch.jit.export
    def forward_lower(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ):
        """Compute DOS from extended (ghost-atom) coordinates and an
        externally supplied neighbor list.
        """
        model_ret = self.forward_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            # Fix: the reduced output of a fitting with var_name "dos" is
            # keyed "dos_redu" (as in forward() above), not "energy_redu"
            # — the latter was a copy-paste from the energy model.
            model_predict["dos"] = model_ret["dos_redu"]
        else:
            model_predict = model_ret
        return model_predict
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.dos import (

Check warning on line 21 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L21

Added line #L21 was not covered by tests
DOSFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
Expand Down Expand Up @@ -45,6 +48,9 @@
from deepmd.pt.model.model.dipole_model import (
DipoleModel,
)
from deepmd.pt.model.model.dos_model import (

Check warning on line 51 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L51

Added line #L51 was not covered by tests
DOSModel,
)
Comment on lines +51 to +53

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.dos_model
begins an import cycle.
from deepmd.pt.model.model.ener_model import (
EnergyModel,
)
Expand All @@ -68,6 +74,8 @@
cls = DipoleModel
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel
elif isinstance(fitting, DOSFittingNet):
cls = DOSModel

Check warning on line 78 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L77-L78

Added lines #L77 - L78 were not covered by tests
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)

Expand Down
109 changes: 109 additions & 0 deletions deepmd/pt/model/task/dos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (

Check warning on line 4 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L2-L4

Added lines #L2 - L4 were not covered by tests
List,
Optional,
Union,
)

import torch

Check warning on line 10 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L10

Added line #L10 was not covered by tests

from deepmd.pt.model.task.ener import (

Check warning on line 12 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L12

Added line #L12 was not covered by tests
InvarFitting,
)
from deepmd.pt.model.task.fitting import (

Check warning on line 15 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L15

Added line #L15 was not covered by tests
Fitting,
)
from deepmd.pt.utils import (

Check warning on line 18 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L18

Added line #L18 was not covered by tests
env,
)
from deepmd.pt.utils.env import (

Check warning on line 21 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L21

Added line #L21 was not covered by tests
DEFAULT_PRECISION,
)
from deepmd.pt.utils.utils import (

Check warning on line 24 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L24

Added line #L24 was not covered by tests
to_numpy_array,
)
from deepmd.utils.version import (

Check warning on line 27 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L27

Added line #L27 was not covered by tests
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

Check warning on line 32 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L31-L32

Added lines #L31 - L32 were not covered by tests

log = logging.getLogger(__name__)

Check warning on line 34 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L34

Added line #L34 was not covered by tests


@Fitting.register("dos")
class DOSFittingNet(InvarFitting):
    """PyTorch fitting network for the electronic density of states (DOS).

    Predicts a per-atom DOS vector of length ``numb_dos``; reduction to a
    per-frame DOS is handled by the parent invariant fitting.

    Parameters
    ----------
    ntypes : int
        Number of atom types.
    dim_descrpt : int
        Dimension of the input descriptor.
    numb_dos : int
        Number of grid points on which the DOS is evaluated
        (output dimension per atom).
    neuron : list[int], optional
        Sizes of the hidden layers of the fitting net.
        Defaults to ``[128, 128, 128]``.
    resnet_dt : bool
        Use a learnable time step ``dt`` in the ResNet construction.
    numb_fparam : int
        Number of frame parameters.
    numb_aparam : int
        Number of atomic parameters.
    rcond : float, optional
        Condition number used when regressing the atomic bias.
    bias_dos : torch.Tensor, optional
        Per-type bias added to the predicted DOS, shape
        ``(ntypes, numb_dos)``. Defaults to zeros.
    trainable : bool or list[bool]
        Whether the network (or each of its layers) is trainable.
    seed : int, optional
        Random seed.
    activation_function : str
        Activation function of the fitting net.
    precision : str
        Numerical precision of the network parameters.
    exclude_types : list[int], optional
        Atom types excluded from the fitting. Defaults to ``[]``.
    mixed_types : bool
        Whether the fitting net treats atom types in a mixed fashion.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        numb_dos: int = 300,
        neuron: Optional[List[int]] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        rcond: Optional[float] = None,
        bias_dos: Optional[torch.Tensor] = None,
        trainable: Union[bool, List[bool]] = True,
        seed: Optional[int] = None,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        exclude_types: Optional[List[int]] = None,
        mixed_types: bool = True,
    ):
        # Avoid shared mutable default arguments: resolve the defaults here.
        if neuron is None:
            neuron = [128, 128, 128]
        if exclude_types is None:
            exclude_types = []
        if bias_dos is not None:
            self.bias_dos = bias_dos
        else:
            # Fix: use the module-level global precision and device.
            # torch.zeros(..., dtype=float) forces float64 on CPU, which
            # mismatches the model's precision and device on GPU runs.
            self.bias_dos = torch.zeros(
                (ntypes, numb_dos), dtype=dtype, device=device
            )
        super().__init__(
            var_name="dos",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            dim_out=numb_dos,
            neuron=neuron,
            # Forward the resolved bias (not the raw argument) so that a
            # None input actually yields the zero bias computed above.
            bias_atom_e=self.bias_dos,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            rcond=rcond,
            seed=seed,
            exclude_types=exclude_types,
            trainable=trainable,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "DOSFittingNet":
        """Construct a DOSFittingNet from a serialized dict.

        Translates the generic serialized keys back into ``__init__``
        arguments and drops keys that are fixed for the DOS fitting.
        """
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("@class")
        data.pop("var_name")
        # These are not accepted by __init__; their values are implied by
        # the DOS fitting type.
        data.pop("tot_ener_zero")
        data.pop("layer_name")
        data.pop("use_aparam_as_mask")
        data.pop("spin")
        data.pop("atom_ener")
        data["numb_dos"] = data.pop("dim_out")
        obj = super().deserialize(data)

        return obj

    def serialize(self) -> dict:
        """Serialize the fitting to dict."""
        dd = {
            **InvarFitting.serialize(self),
            "type": "dos",
            "dim_out": self.dim_out,
        }
        dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e)

        return dd

    # make jit happy with torch 2.0.0
    exclude_types: List[int]
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@

def __setitem__(self, key, value):
if key in ["bias_atom_e"]:
value = value.view([self.ntypes, self._net_out_dim()])
value = value.view([self.ntypes, -1])

Check warning on line 406 in deepmd/pt/model/task/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/fitting.py#L406

Added line #L406 was not covered by tests
self.bias_atom_e = value
elif key in ["fparam_avg"]:
self.fparam_avg = value
Expand Down
Loading
Loading