Skip to content

Commit

Permalink
refactor the torch implementation of the fitting net
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 29, 2024
1 parent 8afd47e commit 19069f3
Show file tree
Hide file tree
Showing 11 changed files with 574 additions and 128 deletions.
4 changes: 4 additions & 0 deletions deepmd/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from .env_mat import (
EnvMat,
)
from .fitting import (
InvarFitting,
)
from .network import (
EmbeddingNet,
FittingNet,
Expand Down Expand Up @@ -34,6 +37,7 @@
)

__all__ = [
"InvarFitting",
"DescrptSeA",
"EnvMat",
"make_multilayer_network",
Expand Down
11 changes: 5 additions & 6 deletions deepmd/model_format/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@

import numpy as np

from deepmd.model_format import (
FittingOutputDef,
OutputVariableDef,
fitting_check_output,
)

from .common import (
DEFAULT_PRECISION,
NativeOP,
Expand All @@ -22,6 +16,11 @@
FittingNet,
NetworkCollection,
)
from .output_def import (
FittingOutputDef,
OutputVariableDef,
fitting_check_output,
)


@fitting_check_output
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(

fitting_net["type"] = fitting_net.get("type", "ener")
if self.descriptor_type not in ["se_e2_a"]:
fitting_net["ntypes"] = 1
fitting_net["ntypes"] = self.descriptor.get_ntype()
else:
fitting_net["ntypes"] = self.descriptor.get_ntype()
fitting_net["use_tebd"] = False
Expand Down Expand Up @@ -165,5 +165,5 @@ def forward_atomic(
)
assert descriptor is not None
# energy, force
fit_ret = self.fitting_net(descriptor, atype, atype_tebd=None, rot_mat=rot_mat)
fit_ret = self.fitting_net(descriptor, atype, gr=rot_mat)
return fit_ret
Loading

0 comments on commit 19069f3

Please sign in to comment.