Skip to content

Commit

Permalink
feat: add dipole consistency test
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Feb 22, 2024
1 parent cf21b7a commit e9dcf0f
Show file tree
Hide file tree
Showing 16 changed files with 227 additions and 8 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def __init__(
r_differentiable: bool = True,
c_differentiable: bool = True,
old_impl=False,
# not used
seed: Optional[int] = None,
):
# seed, uniform_seed are not included
if tot_ener_zero:
Expand Down
18 changes: 12 additions & 6 deletions deepmd/tf/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ class DipoleFittingSeA(Fitting):
Parameters
----------
descrpt : tf.Tensor
The descrptor
ntypes
The ntypes of the descrptor :math:`\mathcal{D}`
dim_descrpt
The dimension of the descrptor :math:`\mathcal{D}`
embedding_width
The rotation matrix dimension of the descrptor :math:`\mathcal{D}`
neuron : List[int]
Number of neurons in each hidden layer of the fitting net
resnet_dt : bool
Expand All @@ -59,7 +63,9 @@ class DipoleFittingSeA(Fitting):

def __init__(
self,
descrpt: tf.Tensor,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
sel_type: Optional[List[int]] = None,
Expand All @@ -70,8 +76,8 @@ def __init__(
**kwargs,
) -> None:
"""Constructor."""
self.ntypes = descrpt.get_ntypes()
self.dim_descrpt = descrpt.get_dim_out()
self.ntypes = ntypes
self.dim_descrpt = dim_descrpt
self.n_neuron = neuron
self.resnet_dt = resnet_dt
self.sel_type = sel_type
Expand All @@ -85,7 +91,7 @@ def __init__(
self.seed_shift = one_layer_rand_seed_shift()
self.fitting_activation_fn = get_activation_func(activation_function)
self.fitting_precision = get_precision(precision)
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
self.dim_rot_mat_1 = embedding_width
self.dim_rot_mat = self.dim_rot_mat_1 * 3
self.useBN = False
self.fitting_net_variables = None
Expand Down
6 changes: 4 additions & 2 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ class EnerFitting(Fitting):
Parameters
----------
descrpt
The descrptor :math:`\mathcal{D}`
ntypes
The ntypes of the descrptor :math:`\mathcal{D}`
dim_descrpt
The dimension of the descrptor :math:`\mathcal{D}`
neuron
Number of neurons :math:`N` in each hidden layer of the fitting net
resnet_dt
Expand Down
190 changes: 190 additions & 0 deletions source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from typing import (
Any,
Tuple,
)

import numpy as np

from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)

from ..common import (
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
parameterized,
)
from .common import (
FittingTest,
)

if INSTALLED_PT:
import torch

from deepmd.pt.model.task.dipole import DipoleFittingNet as DipoleFittingPT
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
else:
DipoleFittingPT = object
if INSTALLED_TF:
from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF
else:
DipoleFittingTF = object
from deepmd.utils.argcheck import (
fitting_dipole,
)


@parameterized(
(True, False), # resnet_dt
("float64", "float32"), # precision
(True, False), # mixed_types
)
class TestDipole(CommonTest, FittingTest, unittest.TestCase):
@property
def data(self) -> dict:
(
resnet_dt,
precision,
mixed_types,
) = self.param
return {
"neuron": [5, 5, 5],
"resnet_dt": resnet_dt,
"precision": precision,
"seed": 20240217,
}

@property
def skip_tf(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
# TODO: mixed_types
return mixed_types or CommonTest.skip_pt

@property
def skip_pt(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
return CommonTest.skip_pt

tf_class = DipoleFittingTF
dp_class = DipoleFittingDP
pt_class = DipoleFittingPT
args = fitting_dipole()

def setUp(self):
CommonTest.setUp(self)

self.ntypes = 2
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION)
self.gr = np.ones((1, 6, 30, 3), dtype=GLOBAL_NP_FLOAT_PRECISION)
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
# inconsistent if not sorted
self.atype.sort()

@property
def addtional_data(self) -> dict:
(
resnet_dt,
precision,
mixed_types,
) = self.param
return {
"ntypes": self.ntypes,
"dim_descrpt": self.inputs.shape[-1],
"mixed_types": mixed_types,
"var_name": "dipole",
"embedding_width": 30,
}

def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
(
resnet_dt,
precision,
mixed_types,
) = self.param
return self.build_tf_fitting(
obj,
self.inputs.ravel(),
self.natoms,
self.atype,
suffix,
)

def eval_pt(self, pt_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
) = self.param
return (
pt_obj(
torch.from_numpy(self.inputs).to(device=PT_DEVICE),
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE),
torch.from_numpy(self.gr).to(device=PT_DEVICE),
None,
)["dipole"]
.detach()
.cpu()
.numpy()
)

def eval_dp(self, dp_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
) = self.param
return dp_obj(
self.inputs,
self.atype.reshape(1, -1),
self.gr,
None,
)["dipole"]

def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
ret = ret[0].reshape(-1, self.natoms[0], 1)
return (ret,)

@property
def rtol(self) -> float:
"""Relative tolerance for comparing the return value."""
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float64":
return 1e-10
elif precision == "float32":
return 1e-4
else:
raise ValueError(f"Unknown precision: {precision}")

@property
def atol(self) -> float:
"""Absolute tolerance for comparing the return value."""
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float64":
return 1e-10
elif precision == "float32":
return 1e-4
else:
raise ValueError(f"Unknown precision: {precision}")
3 changes: 3 additions & 0 deletions source/tests/tf/test_data_large_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_data_mixed_type(self):
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down Expand Up @@ -311,6 +312,7 @@ def test_stripped_data_mixed_type(self):
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down Expand Up @@ -508,6 +510,7 @@ def test_compressible_data_mixed_type(self):
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down
1 change: 1 addition & 0 deletions source/tests/tf/test_fitting_ener_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_fitting(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)

# model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']])
Expand Down
3 changes: 3 additions & 0 deletions source/tests/tf/test_model_se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_model_atom_ener(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
model = EnerModel(descrpt, fitting)

Expand Down Expand Up @@ -157,6 +158,7 @@ def test_model(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
model = EnerModel(descrpt, fitting)

Expand Down Expand Up @@ -302,6 +304,7 @@ def test_model_atom_ener_type_embedding(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
model = EnerModel(descrpt, fitting, typeebd=typeebd)

Expand Down
1 change: 1 addition & 0 deletions source/tests/tf/test_model_se_a_aparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_model(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
model = EnerModel(descrpt, fitting)

Expand Down
1 change: 1 addition & 0 deletions source/tests/tf/test_model_se_a_ebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_model(self):
)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(
**jdata["model"]["fitting_net"],
)
Expand Down
1 change: 1 addition & 0 deletions source/tests/tf/test_model_se_a_ebd_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_model(self):
)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(
**jdata["model"]["fitting_net"],
)
Expand Down
1 change: 1 addition & 0 deletions source/tests/tf/test_model_se_a_fparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_model(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
# descrpt = DescrptSeA(jdata['model']['descriptor'])
# fitting = EnerFitting(jdata['model']['fitting_net'], descrpt)
Expand Down
1 change: 1 addition & 0 deletions source/tests/tf/test_model_se_a_srtab.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_model(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
# descrpt = DescrptSeA(jdata['model']['descriptor'])
# fitting = EnerFitting(jdata['model']['fitting_net'], descrpt)
Expand Down
1 change: 1 addition & 0 deletions source/tests/tf/test_model_se_a_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_model(self):
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down
4 changes: 4 additions & 0 deletions source/tests/tf/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def test_model(self):
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down Expand Up @@ -295,6 +296,7 @@ def test_compressible_model(self):
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down Expand Up @@ -523,6 +525,7 @@ def test_stripped_type_embedding_model(self):
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down Expand Up @@ -762,6 +765,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self):
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
Expand Down
Loading

0 comments on commit e9dcf0f

Please sign in to comment.