From 19069f3f20d56fa46ef8a38d38d450befc9fb30c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 29 Jan 2024 22:36:52 +0800 Subject: [PATCH] refactor the torch implementation of the fitting net --- deepmd/model_format/__init__.py | 4 + deepmd/model_format/fitting.py | 11 +- deepmd/pt/model/model/dp_atomic_model.py | 4 +- deepmd/pt/model/task/ener.py | 376 ++++++++++++++++++++--- deepmd/pt/model/task/fitting.py | 13 +- deepmd/pt/model/task/task.py | 18 +- deepmd/pt/utils/utils.py | 32 ++ source/tests/pt/test_ener_fitting.py | 162 ++++++++++ source/tests/pt/test_fitting_net.py | 24 +- source/tests/pt/test_model.py | 25 +- source/tests/pt/test_se_e2_a.py | 33 +- 11 files changed, 574 insertions(+), 128 deletions(-) create mode 100644 source/tests/pt/test_ener_fitting.py diff --git a/deepmd/model_format/__init__.py b/deepmd/model_format/__init__.py index 253bca3507..e15f73758e 100644 --- a/deepmd/model_format/__init__.py +++ b/deepmd/model_format/__init__.py @@ -7,6 +7,9 @@ from .env_mat import ( EnvMat, ) +from .fitting import ( + InvarFitting, +) from .network import ( EmbeddingNet, FittingNet, @@ -34,6 +37,7 @@ ) __all__ = [ + "InvarFitting", "DescrptSeA", "EnvMat", "make_multilayer_network", diff --git a/deepmd/model_format/fitting.py b/deepmd/model_format/fitting.py index b3195cd26e..8f79ae3491 100644 --- a/deepmd/model_format/fitting.py +++ b/deepmd/model_format/fitting.py @@ -8,12 +8,6 @@ import numpy as np -from deepmd.model_format import ( - FittingOutputDef, - OutputVariableDef, - fitting_check_output, -) - from .common import ( DEFAULT_PRECISION, NativeOP, @@ -22,6 +16,11 @@ FittingNet, NetworkCollection, ) +from .output_def import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) @fitting_check_output diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index 853eacb875..245c0f3d3f 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -94,7 +94,7 @@ def __init__( fitting_net["type"] = fitting_net.get("type", "ener") if self.descriptor_type not in ["se_e2_a"]: - fitting_net["ntypes"] = 1 + fitting_net["ntypes"] = self.descriptor.get_ntype() else: fitting_net["ntypes"] = self.descriptor.get_ntype() fitting_net["use_tebd"] = False @@ -165,5 +165,5 @@ def forward_atomic( ) assert descriptor is not None # energy, force - fit_ret = self.fitting_net(descriptor, atype, atype_tebd=None, rot_mat=rot_mat) + fit_ret = self.fitting_net(descriptor, atype, gr=rot_mat) return fit_ret diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 03043e2fcb..6b33491416 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging from typing import ( + List, Optional, Tuple, ) +import numpy as np import torch from deepmd.model_format import ( @@ -12,6 +15,10 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.pt.model.network.mlp import ( + FittingNet, + NetworkCollection, +) from deepmd.pt.model.network.network import ( ResidualDeep, ) @@ -21,19 +28,35 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE -@Fitting.register("ener") @fitting_check_output -class EnergyFittingNet(Fitting): +class InvarFitting(Fitting): def __init__( self, - ntypes, - 
embedding_width, - neuron, - bias_atom_e, - resnet_dt=True, - use_tebd=True, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out: int, + neuron: List[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + distinguish_types: bool = False, **kwargs, ): """Construct a fitting net for energy. @@ -46,67 +69,325 @@ def __init__( - resnet_dt: Using time-step in the ResNet construction. """ super().__init__() + self.var_name = var_name self.ntypes = ntypes - self.embedding_width = embedding_width - self.use_tebd = use_tebd - if not use_tebd: - assert self.ntypes == len(bias_atom_e), "Element count mismatches!" - bias_atom_e = torch.tensor(bias_atom_e) + self.dim_descrpt = dim_descrpt + self.dim_out = dim_out + self.neuron = neuron + self.distinguish_types = distinguish_types + self.use_tebd = not self.distinguish_types + self.resnet_dt = resnet_dt + self.numb_fparam = numb_fparam + self.numb_aparam = numb_aparam + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + if bias_atom_e is None: + bias_atom_e = np.zeros([self.ntypes, self.dim_out]) + bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device) + bias_atom_e = bias_atom_e.view([self.ntypes, self.dim_out]) + if not self.use_tebd: + assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" self.register_buffer("bias_atom_e", bias_atom_e) + # init constants + if self.numb_fparam > 0: + self.register_buffer( + "fparam_avg", + torch.zeros(self.numb_fparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "fparam_inv_std", + torch.ones(self.numb_fparam, dtype=self.prec, device=device), + ) + else: + self.fparam_avg, self.fparam_inv_std = None, None + if self.numb_aparam > 0: + self.register_buffer( + "aparam_avg", + torch.zeros(self.numb_aparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "aparam_inv_std", + torch.ones(self.numb_aparam, dtype=self.prec, device=device), + ) + else: + self.aparam_avg, self.aparam_inv_std = None, None - filter_layers = [] - for type_i in range(self.ntypes): - bias_type = 0.0 - one = ResidualDeep( - type_i, embedding_width, neuron, bias_type, resnet_dt=resnet_dt + in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam + out_dim = 1 + + self.old_impl = kwargs.get("old_impl", False) + if self.old_impl: + filter_layers = [] + for type_i in range(self.ntypes): + bias_type = 0.0 + one = ResidualDeep( + type_i, + self.dim_descrpt, + self.neuron, + bias_type, + resnet_dt=self.resnet_dt, + ) + filter_layers.append(one) + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + self.filter_layers = None + else: + self.filter_layers = NetworkCollection( + 1 if self.distinguish_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + out_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + ) + for ii in range(self.ntypes if self.distinguish_types else 1) + ], ) - filter_layers.append(one) - self.filter_layers = torch.nn.ModuleList(filter_layers) + self.filter_layers_old = None + # very bad design... 
if "seed" in kwargs: logging.info("Set seed to %d in fitting net.", kwargs["seed"]) torch.manual_seed(kwargs["seed"]) - def output_def(self): + def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ - OutputVariableDef("energy", [1], reduciable=True, differentiable=True), + OutputVariableDef( + self.var_name, [self.dim_out], reduciable=True, differentiable=True + ), ] ) + def __setitem__(self, key, value): + if key in ("bias_atom_e"): + # correct bias_atom_e shape. user may provide stupid shape + self.bias_atom_e = value + elif key in ("fparam_avg"): + self.fparam_avg = value + elif key in ("fparam_inv_std"): + self.fparam_inv_std = value + elif key in ("aparam_avg"): + self.aparam_avg = value + elif key in ("aparam_inv_std"): + self.aparam_inv_std = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("bias_atom_e"): + return self.bias_atom_e + elif key in ("fparam_avg"): + return self.fparam_avg + elif key in ("fparam_inv_std"): + return self.fparam_inv_std + elif key in ("aparam_avg"): + return self.aparam_avg + elif key in ("aparam_inv_std"): + return self.aparam_inv_std + else: + raise KeyError(key) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "var_name": self.var_name, + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "dim_out": self.dim_out, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "activation_function": self.activation_function, + "precision": self.precision, + "distinguish_types": self.distinguish_types, + "nets": self.filter_layers.serialize(), + "@variables": { + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "fparam_avg": to_numpy_array(self.fparam_avg), + "fparam_inv_std": to_numpy_array(self.fparam_inv_std), + "aparam_avg": to_numpy_array(self.aparam_avg), + "aparam_inv_std": to_numpy_array(self.aparam_inv_std), + }, + # "rcond": self.rcond , + # "tot_ener_zero": self.tot_ener_zero , + # "trainable": self.trainable , + # "atom_ener": self.atom_ener , + # "layer_name": self.layer_name , + # "use_aparam_as_mask": self.use_aparam_as_mask , + # "spin": self.spin , + ## NOTICE: not supported by far + "rcond": None, + "tot_ener_zero": False, + "trainable": True, + "atom_ener": None, + "layer_name": None, + "use_aparam_as_mask": False, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "InvarFitting": + data = copy.deepcopy(data) + variables = data.pop("@variables") + nets = data.pop("nets") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = to_torch_tensor(variables[kk]) + obj.filter_layers = NetworkCollection.deserialize(nets) + return obj + + def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: + return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) + + def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor: + return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1]) + def forward( self, - inputs: torch.Tensor, + descriptor: torch.Tensor, atype: torch.Tensor, - atype_tebd: Optional[torch.Tensor] = None, - rot_mat: Optional[torch.Tensor] = None, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, ): """Based on embedding net output, alculate total energy. Args: - - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.embedding_width]. 
+        - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt].
         - natoms: Tell atom count and element count. Its shape is [2+self.ntypes].

         Returns
         -------
         - `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
         """
+        xx = descriptor
+        nf, nloc, nd = xx.shape
+        dtype = descriptor.dtype
+        device = env.DEVICE
+        # NOTICE in tests/pt/test_model.py
+        # it happens that the user directly accesses the data member self.bias_atom_e
+        # and sets it to a wrong shape!
+        self.bias_atom_e = self.bias_atom_e.view([self.ntypes, self.dim_out])
+        # check input dim
+        if nd != self.dim_descrpt:
+            raise ValueError(
+                f"get an input descriptor of dim {nd}, "
+                f"which is not consistent with {self.dim_descrpt}."
+            )
+        # check fparam dim, concatenate to input descriptor
+        if self.numb_fparam > 0:
+            assert fparam is not None, "fparam should not be None"
+            assert self.fparam_avg is not None
+            assert self.fparam_inv_std is not None
+            if fparam.shape[-1] != self.numb_fparam:
+                raise ValueError(
+                    f"get an input fparam of dim {fparam.shape[-1]}, "
+                    f"which is not consistent with {self.numb_fparam}."
+                )
+            nb, _ = fparam.shape
+            t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb)
+            t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb)
+            fparam = (fparam - t_fparam_avg) * t_fparam_inv_std
+            fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1])
+            xx = torch.cat(
+                [xx, fparam],
+                dim=-1,
+            )
+        # check aparam dim, concatenate to input descriptor
+        if self.numb_aparam > 0:
+            assert aparam is not None, "aparam should not be None"
+            assert self.aparam_avg is not None
+            assert self.aparam_inv_std is not None
+            if aparam.shape[-1] != self.numb_aparam:
+                raise ValueError(
+                    f"get an input aparam of dim {aparam.shape[-1]}, "
+                    f"which is not consistent with {self.numb_aparam}."
+                )
+            nb, nloc, _ = aparam.shape
+            t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc)
+            t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc)
+            aparam = (aparam - t_aparam_avg) * t_aparam_inv_std
+            xx = torch.cat(
+                [xx, aparam],
+                dim=-1,
+            )
+
         outs = torch.zeros_like(atype).unsqueeze(-1)  # jit assertion
-        if self.use_tebd:
-            if atype_tebd is not None:
-                inputs = torch.concat([inputs, atype_tebd], dim=-1)
-            atom_energy = self.filter_layers[0](inputs) + self.bias_atom_e[
-                atype
-            ].unsqueeze(-1)
-            outs = outs + atom_energy  # Shape is [nframes, natoms[0], 1]
+        if self.old_impl:
+            outs = torch.zeros_like(atype).unsqueeze(-1)  # jit assertion
+            assert self.filter_layers_old is not None
+            if self.use_tebd:
+                atom_energy = self.filter_layers_old[0](xx) + self.bias_atom_e[
+                    atype
+                ].unsqueeze(-1)
+                outs = outs + atom_energy  # Shape is [nframes, natoms[0], 1]
+            else:
+                for type_i, filter_layer in enumerate(self.filter_layers_old):
+                    mask = atype == type_i
+                    atom_energy = filter_layer(xx)
+                    atom_energy = atom_energy + self.bias_atom_e[type_i]
+                    atom_energy = atom_energy * mask.unsqueeze(-1)
+                    outs = outs + atom_energy  # Shape is [nframes, natoms[0], 1]
+            return {"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
         else:
-            for type_i, filter_layer in enumerate(self.filter_layers):
-                mask = atype == type_i
-                atom_energy = filter_layer(inputs)
-                atom_energy = atom_energy + self.bias_atom_e[type_i]
-                atom_energy = atom_energy * mask.unsqueeze(-1)
+            if self.use_tebd:
+                atom_energy = (
+                    self.filter_layers.networks[0](xx) + self.bias_atom_e[atype]
+                )
                 outs = outs + atom_energy  # Shape is [nframes, natoms[0], 1]
-        return {"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
+            else:
+                for type_i, ll
in enumerate(self.filter_layers.networks): + mask = (atype == type_i).unsqueeze(-1) + mask = torch.tile(mask, (1, 1, self.dim_out)) + atom_energy = ll(xx) + atom_energy = atom_energy + self.bias_atom_e[type_i] + atom_energy = atom_energy * mask + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + + +@Fitting.register("ener") +@fitting_check_output +class EnergyFittingNet(InvarFitting): + def __init__( + self, + ntypes: int, + embedding_width: int, + neuron: List[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + use_tebd: bool = True, + **kwargs, + ): + super().__init__( + "energy", + ntypes, + embedding_width, + 1, + neuron=neuron, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + activation_function=activation_function, + precision=precision, + use_tebd=use_tebd, + **kwargs, + ) @Fitting.register("direct_force") @@ -136,7 +417,7 @@ def __init__( """ super().__init__() self.ntypes = ntypes - self.embedding_width = embedding_width + self.dim_descrpt = embedding_width self.use_tebd = use_tebd self.out_dim = out_dim if not use_tebd: @@ -186,13 +467,12 @@ def forward( self, inputs: torch.Tensor, atype: torch.Tensor, - atype_tebd: Optional[torch.Tensor] = None, - rot_mat: Optional[torch.Tensor] = None, + gr: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, None]: """Based on embedding net output, alculate total energy. Args: - - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.embedding_width]. + - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt]. - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. 
Returns @@ -201,19 +481,19 @@ def forward( """ nframes, nloc, _ = inputs.size() if self.use_tebd: - if atype_tebd is not None: - inputs = torch.concat([inputs, atype_tebd], dim=-1) + # if atype_tebd is not None: + # inputs = torch.concat([inputs, atype_tebd], dim=-1) vec_out = self.filter_layers_dipole[0]( inputs ) # Shape is [nframes, nloc, m1] assert list(vec_out.size()) == [nframes, nloc, self.out_dim] # (nf x nloc) x 1 x od vec_out = vec_out.view(-1, 1, self.out_dim) - assert rot_mat is not None + assert gr is not None # (nf x nloc) x od x 3 - rot_mat = rot_mat.view(-1, self.out_dim, 3) + gr = gr.view(-1, self.out_dim, 3) vec_out = ( - torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3) + torch.bmm(vec_out, gr).squeeze(-2).view(nframes, nloc, 3) ) # Shape is [nframes, nloc, 3] else: vec_out = torch.zeros_like(atype).unsqueeze(-1) # jit assertion diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 16e80f9c20..c6fb6b27e1 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -7,9 +7,6 @@ import numpy as np import torch -from deepmd.model_format import ( - FittingOutputDef, -) from deepmd.pt.model.task.task import ( TaskBaseMethod, ) @@ -61,17 +58,9 @@ def __new__(cls, *args, **kwargs): if fitting_type in Fitting.__plugins.plugins: cls = Fitting.__plugins.plugins[fitting_type] else: - raise RuntimeError("Unknown descriptor type: " + fitting_type) + raise RuntimeError("Unknown fitting type: " + fitting_type) return super().__new__(cls) - def output_def(self) -> FittingOutputDef: - """Definition for the task Output.""" - raise NotImplementedError - - def forward(self, **kwargs): - """Task Output.""" - raise NotImplementedError - def share_params(self, base_class, shared_level, resume=False): assert ( self.__class__ == base_class.__class__ diff --git a/deepmd/pt/model/task/task.py b/deepmd/pt/model/task/task.py index a9b2efeb9a..b2dc03e4bd 100644 --- a/deepmd/pt/model/task/task.py +++ b/deepmd/pt/model/task/task.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) + import torch +from deepmd.model_format import ( + FittingOutputDef, +) -class TaskBaseMethod(torch.nn.Module): - def __init__(self, **kwargs): - """Construct a basic head for different tasks.""" - super().__init__() - def forward(self, **kwargs): - """Task Output.""" +class TaskBaseMethod(torch.nn.Module, ABC): + @abstractmethod + def output_def(self) -> FittingOutputDef: + """Definition for the task Output.""" raise NotImplementedError diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 780dbf7e62..516cbbdba6 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -4,9 +4,17 @@ Optional, ) +import numpy as np import torch import torch.nn.functional as F +from deepmd.model_format.common import PRECISION_DICT as NP_PRECISION_DICT + +from .env import ( + DEVICE, +) +from .env import PRECISION_DICT as PT_PRECISION_DICT + def get_activation_fn(activation: str) -> Callable: """Returns the activation function corresponding to `activation`.""" @@ -41,3 +49,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x else: raise RuntimeError(f"activation function {self.activation} not supported") + + +def to_numpy_array( + xx: torch.Tensor, +) -> np.ndarray: + if xx is not None: + prec = [key for key, value in PT_PRECISION_DICT.items() if value == xx.dtype] + if len(prec) == 0: + raise ValueError(f"unknown precision {xx.dtype}") + else: + prec = 
NP_PRECISION_DICT[prec[0]] + return xx.detach().cpu().numpy().astype(prec) if xx is not None else None + + +def to_torch_tensor( + xx: np.ndarray, +) -> torch.Tensor: + if xx is not None: + prec = [key for key, value in NP_PRECISION_DICT.items() if value == xx.dtype] + if len(prec) == 0: + raise ValueError(f"unknown precision {xx.dtype}") + else: + prec = PT_PRECISION_DICT[prec[0]] + return torch.tensor(xx, dtype=prec, device=DEVICE) if xx is not None else None diff --git a/source/tests/pt/test_ener_fitting.py b/source/tests/pt/test_ener_fitting.py new file mode 100644 index 0000000000..f0d6bdb932 --- /dev/null +++ b/source/tests/pt/test_ener_fitting.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.model_format import InvarFitting as DPInvarFitting +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE) + + for od, distinguish_types, nfp, nap in itertools.product( + [1, 3], + [True, False], + [0, 3], + [0, 4], + ): + ft0 = InvarFitting( + "foo", + self.nt, + dd0.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=(not distinguish_types), + ).to(env.DEVICE) + ft1 = DPInvarFitting.deserialize(ft0.serialize()) + ft2 = InvarFitting.deserialize(ft0.serialize()) + + if nfp > 0: + ifp = torch.tensor( + rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE + ) + else: + ifp = None + if nap > 0: + iap = torch.tensor( + rng.normal(size=(self.nf, self.nloc, nap)), + dtype=dtype, + device=env.DEVICE, + ) + else: + iap = None + + ret0 = ft0(rd0, atype, fparam=ifp, aparam=iap) + ret1 = ft1( + rd0.detach().cpu().numpy(), + atype.detach().cpu().numpy(), + fparam=to_numpy_array(ifp), + aparam=to_numpy_array(iap), + ) + ret2 = ft2(rd0, atype, fparam=ifp, aparam=iap) + np.testing.assert_allclose( + to_numpy_array(ret0["foo"]), + ret1["foo"], + ) + np.testing.assert_allclose( + to_numpy_array(ret0["foo"]), + to_numpy_array(ret2["foo"]), + ) + + def test_new_old( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + dd = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + rd0, _, _, _, _ = dd( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE) + + od = 1 + for distinguish_types in itertools.product( + [True, False], + ): + ft0 = EnergyFittingNet( + self.nt, + dd.dim_out, + distinguish_types=distinguish_types, + ).to(env.DEVICE) + ft1 = 
EnergyFittingNet( + self.nt, + dd.dim_out, + distinguish_types=distinguish_types, + old_impl=True, + ).to(env.DEVICE) + dd0 = ft0.state_dict() + dd1 = ft1.state_dict() + for kk, vv in dd1.items(): + new_kk = kk + new_kk = new_kk.replace("filter_layers_old", "filter_layers.networks") + new_kk = new_kk.replace("deep_layers", "layers") + new_kk = new_kk.replace("final_layer", "layers.3") + dd1[kk] = dd0[new_kk] + if kk.split(".")[-1] in ["idt", "bias"]: + dd1[kk] = dd1[kk].unsqueeze(0) + dd1["bias_atom_e"] = dd0["bias_atom_e"] + ft1.load_state_dict(dd1) + ret0 = ft0(rd0, atype) + ret1 = ft1(rd0, atype) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + + def test_jit( + self, + ): + for od, distinguish_types, nfp, nap in itertools.product( + [1, 3], + [True, False], + [0, 3], + [0, 4], + ): + ft0 = InvarFitting( + "foo", + self.nt, + 9, + od, + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=(not distinguish_types), + ).to(env.DEVICE) + torch.jit.script(ft0) diff --git a/source/tests/pt/test_fitting_net.py b/source/tests/pt/test_fitting_net.py index 3feb4f4739..ed2c428de5 100644 --- a/source/tests/pt/test_fitting_net.py +++ b/source/tests/pt/test_fitting_net.py @@ -102,25 +102,25 @@ def test_consistency(self): my_fn = EnergyFittingNet( self.ntypes, self.embedding_width, - self.n_neuron, - self.dp_fn.bias_atom_e, - use_tebd=False, + neuron=self.n_neuron, + bias_atom_e=self.dp_fn.bias_atom_e, + distinguish_types=True, ) for name, param in my_fn.named_parameters(): - matched = re.match("filter_layers\.(\d).deep_layers\.(\d)\.([a-z]+)", name) + matched = re.match( + "filter_layers\.networks\.(\d).layers\.(\d)\.([a-z]+)", name + ) key = None if matched: + if int(matched.group(2)) == len(self.n_neuron): + layer_id = -1 + else: + layer_id = matched.group(2) key = gen_key( type_id=matched.group(1), - layer_id=matched.group(2), + layer_id=layer_id, w_or_b=matched.group(3), ) - else: - matched = re.match("filter_layers\.(\d).final_layer\.([a-z]+)", name) - if matched: - key = gen_key( - type_id=matched.group(1), layer_id=-1, w_or_b=matched.group(2) - ) assert key is not None var = values[key] with torch.no_grad(): @@ -132,7 +132,7 @@ def test_consistency(self): ret = my_fn(embedding, atype) my_energy = ret["energy"] my_energy = my_energy.detach() - self.assertTrue(np.allclose(dp_energy, my_energy.numpy().reshape([-1]))) + np.testing.assert_allclose(dp_energy, my_energy.numpy().reshape([-1])) if __name__ == "__main__": diff --git a/source/tests/pt/test_model.py b/source/tests/pt/test_model.py index 5bbbc9e352..c6595e6471 100644 --- a/source/tests/pt/test_model.py +++ b/source/tests/pt/test_model.py @@ -53,23 +53,24 @@ VariableState = collections.namedtuple("VariableState", ["value", "gradient"]) -def torch2tf(torch_name): +def torch2tf(torch_name, last_layer_id=None): fields = torch_name.split(".") offset = int(fields[2] == "networks") element_id = int(fields[2 + offset]) if fields[0] == "descriptor": layer_id = int(fields[4 + offset]) + 1 weight_type = fields[5 + offset] - return "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id) - elif fields[3] == "deep_layers": - layer_id = int(fields[4]) - weight_type = fields[5] - return "layer_%d_type_%d/%s:0" % (layer_id, element_id, weight_type) - elif fields[3] == "final_layer": - weight_type = fields[4] - return "final_layer_type_%d/%s:0" % (element_id, weight_type) + ret = "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id) + elif fields[0] == "fitting_net": + layer_id = 
int(fields[4 + offset]) + weight_type = fields[5 + offset] + if layer_id != last_layer_id: + ret = "layer_%d_type_%d/%s:0" % (layer_id, element_id, weight_type) + else: + ret = "final_layer_type_%d/%s:0" % (element_id, weight_type) else: raise RuntimeError("Unexpected parameter name: %s" % torch_name) + return ret class DpTrainer: @@ -290,7 +291,7 @@ def test_consistency(self): "neuron": self.filter_neuron, "axis_neuron": self.axis_neuron, }, - "fitting_net": {"neuron": self.n_neuron}, + "fitting_net": {"neuron": self.n_neuron, "distinguish_types": True}, "data_stat_nbatch": self.data_stat_nbatch, "type_map": self.type_map, }, @@ -323,7 +324,7 @@ def test_consistency(self): # Keep parameter value consistency between 2 implentations for name, param in my_model.named_parameters(): name = name.replace("sea.", "") - var_name = torch2tf(name) + var_name = torch2tf(name, last_layer_id=len(self.n_neuron)) var = vs_dict[var_name].value with torch.no_grad(): src = torch.from_numpy(var) @@ -404,7 +405,7 @@ def step(step_id): for name, param in my_model.named_parameters(): name = name.replace("sea.", "") - var_name = torch2tf(name) + var_name = torch2tf(name, last_layer_id=len(self.n_neuron)) var_grad = vs_dict[var_name].gradient param_grad = param.grad.cpu() var_grad = torch.tensor(var_grad) diff --git a/source/tests/pt/test_se_e2_a.py b/source/tests/pt/test_se_e2_a.py index c0a106cb16..0da80ea1ea 100644 --- a/source/tests/pt/test_se_e2_a.py +++ b/source/tests/pt/test_se_e2_a.py @@ -25,6 +25,9 @@ PRECISION_DICT, ) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) from .test_mlp import ( get_tols, ) @@ -32,36 +35,6 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION -class TestCaseSingleFrameWithNlist: - def setUp(self): - # nloc == 3, nall == 4 - self.nloc = 3 - self.nall = 4 - self.nf, self.nt = 1, 2 - self.coord_ext = np.array( - [ - [0, 0, 0], - [0, 1, 0], - [0, 0, 1], - [0, -2, 0], - ], - dtype=np.float64, - ).reshape([1, self.nall * 3]) - self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) - # sel = [5, 2] - self.sel = [5, 2] - self.nlist = np.array( - [ - [1, 3, -1, -1, -1, 2, -1], - [0, -1, -1, -1, -1, 2, -1], - [0, 1, -1, -1, -1, 0, -1], - ], - dtype=int, - ).reshape([1, self.nloc, sum(self.sel)]) - self.rcut = 0.4 - self.rcut_smth = 2.2 - - # to be merged with the tf test case @unittest.skipIf(not support_se_e2_a, "EnvMat not supported") class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist):
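
For reviewers who want to exercise the refactored fitting net by hand, below is a minimal sketch (not part of the patch) of the new `InvarFitting` API and its serialization round trip against the backend-independent numpy implementation. It mirrors the checks added in source/tests/pt/test_ener_fitting.py; the sizes (`nf`, `nloc`, descriptor width, `ntypes`, neuron widths) are made-up illustration values, and a random tensor stands in for a real descriptor output. Optional `numb_fparam`/`numb_aparam` inputs would be passed to `forward` via `fparam=`/`aparam=` in the same way.

```python
import numpy as np
import torch

from deepmd.model_format import InvarFitting as DPInvarFitting  # numpy reference impl
from deepmd.pt.model.task.ener import InvarFitting
from deepmd.pt.utils import env
from deepmd.pt.utils.utils import to_numpy_array

# made-up sizes for illustration only
nf, nloc, nd, ntypes = 2, 3, 10, 2
rng = np.random.default_rng(0)

# stand-in for a descriptor output: [nframes, nloc, dim_descrpt]
descriptor = torch.tensor(
    rng.normal(size=(nf, nloc, nd)),
    dtype=env.GLOBAL_PT_FLOAT_PRECISION,
    device=env.DEVICE,
)
atype = torch.tensor(
    rng.integers(0, ntypes, size=(nf, nloc)), dtype=torch.long, device=env.DEVICE
)

# new-style fitting net: one FittingNet per type when distinguish_types=True
ft_pt = InvarFitting(
    "energy", ntypes, nd, 1, neuron=[24, 24, 24], distinguish_types=True
).to(env.DEVICE)
ret_pt = ft_pt(descriptor, atype)  # {"energy": tensor of shape [nf, nloc, 1]}

# round trip through the backend-independent numpy implementation
ft_np = DPInvarFitting.deserialize(ft_pt.serialize())
ret_np = ft_np(descriptor.detach().cpu().numpy(), atype.detach().cpu().numpy())

# the torch and numpy fitting nets should agree
np.testing.assert_allclose(to_numpy_array(ret_pt["energy"]), ret_np["energy"])
```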