Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make Fittings pluginable #2541

Merged
merged 1 commit into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
)
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Expand Down Expand Up @@ -43,7 +44,7 @@ class Descriptor(PluginVariant):
__plugins = Plugin()

@staticmethod
def register(key: str) -> "Descriptor":
def register(key: str) -> Callable:
"""Register a descriptor plugin.

Parameters
Expand Down
4 changes: 4 additions & 0 deletions deepmd/fit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from .ener import (
EnerFitting,
)
from .fitting import (
Fitting,
)
from .polar import (
GlobalPolarFittingSeA,
PolarFittingSeA,
Expand All @@ -18,4 +21,5 @@
"DOSFitting",
"GlobalPolarFittingSeA",
"PolarFittingSeA",
"Fitting",
]
1 change: 1 addition & 0 deletions deepmd/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)


@Fitting.register("dipole")
class DipoleFittingSeA(Fitting):
r"""Fit the atomic dipole with descriptor se_a.

Expand Down
1 change: 1 addition & 0 deletions deepmd/fit/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
log = logging.getLogger(__name__)


@Fitting.register("dos")
class DOSFitting(Fitting):
r"""Fitting the density of states (DOS) of the system.
The energy should be shifted by the fermi level.
Expand Down
1 change: 1 addition & 0 deletions deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
log = logging.getLogger(__name__)


@Fitting.register("ener")
class EnerFitting(Fitting):
r"""Fitting the energy of the system. The force and the virial can also be trained.

Expand Down
46 changes: 45 additions & 1 deletion deepmd/fit/fitting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,53 @@
from typing import (
Callable,
)

from deepmd.env import (
tf,
)
from deepmd.utils import (
Plugin,
PluginVariant,
)


class Fitting(PluginVariant):
__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a Fitting plugin.

Parameters
----------
key : str
the key of a Fitting

Returns
-------
Fitting
the registered Fitting

Examples
--------
>>> @Fitting.register("some_fitting")
class SomeFitting(Fitting):
pass
"""
return Fitting.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is Fitting:
try:
fitting_type = kwargs["type"]
except KeyError:
raise KeyError("the type of fitting should be set by `type`")
if fitting_type in Fitting.__plugins.plugins:
cls = Fitting.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown descriptor type: " + fitting_type)
return super().__new__(cls)

class Fitting:
@property
def precision(self) -> tf.DType:
"""Precision of fitting network."""
Expand Down
1 change: 1 addition & 0 deletions deepmd/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)


@Fitting.register("polar")
class PolarFittingSeA(Fitting):
r"""Fit the atomic polarizability with descriptor se_a.

Expand Down
35 changes: 7 additions & 28 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@
tfv2,
)
from deepmd.fit import (
DipoleFittingSeA,
DOSFitting,
EnerFitting,
PolarFittingSeA,
Fitting,
)
from deepmd.loss import (
DOSLoss,
Expand Down Expand Up @@ -157,30 +154,13 @@ def _init_param(self, jdata):
self.descrpt = Descriptor(**descrpt_param)

# fitting net
def fitting_net_init(fitting_type_, descrpt_type_, params):
if fitting_type_ == "ener":
params["spin"] = self.spin
return EnerFitting(**params)
elif fitting_type_ == "dos":
return DOSFitting(**params)
elif fitting_type_ == "dipole":
return DipoleFittingSeA(**params)
elif fitting_type_ == "polar":
return PolarFittingSeA(**params)
# elif fitting_type_ == 'global_polar':
# if descrpt_type_ == 'se_e2_a':
# return GlobalPolarFittingSeA(**params)
# else:
# raise RuntimeError('fitting global_polar only supports descrptors: loc_frame and se_e2_a')
else:
raise RuntimeError("unknown fitting type " + fitting_type_)

if not self.multi_task_mode:
fitting_type = fitting_param.get("type", "ener")
self.fitting_type = fitting_type
fitting_param.pop("type", None)
fitting_param["descrpt"] = self.descrpt
self.fitting = fitting_net_init(fitting_type, descrpt_type, fitting_param)
if fitting_type == "ener":
fitting_param["spin"] = self.spin
self.fitting = Fitting(**fitting_param)
else:
self.fitting_dict = {}
self.fitting_type_dict = {}
Expand All @@ -189,11 +169,10 @@ def fitting_net_init(fitting_type_, descrpt_type_, params):
item_fitting_param = fitting_param[item]
item_fitting_type = item_fitting_param.get("type", "ener")
self.fitting_type_dict[item] = item_fitting_type
item_fitting_param.pop("type", None)
item_fitting_param["descrpt"] = self.descrpt
self.fitting_dict[item] = fitting_net_init(
item_fitting_type, descrpt_type, item_fitting_param
)
if item_fitting_type == "ener":
item_fitting_param["spin"] = self.spin
self.fitting_dict[item] = Fitting(**item_fitting_param)

# type embedding
padding = False
Expand Down