Skip to content

Commit

Permalink
fix: pt: energy model forward lower is not tested and has bugs. (#3235)
Browse files Browse the repository at this point in the history
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
wanghan-iapcm and Han Wang authored Feb 6, 2024
1 parent 18c43f6 commit 13a781f
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 3 deletions.
7 changes: 4 additions & 3 deletions deepmd/pt/model/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def forward(
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-3)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
else:
model_predict["force"] = model_ret["dforce"]
else:
Expand All @@ -64,7 +64,7 @@ def forward_lower(
mapping: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.common_forward_lower(
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
Expand All @@ -77,10 +77,11 @@ def forward_lower(
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret[
"energy_derv_c"
].squeeze(-3)
].squeeze(-2)
else:
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
Expand Down
143 changes: 143 additions & 0 deletions source/tests/pt/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from deepmd.pt.model.model.ener import (
DPModel,
EnergyModel,
)
from deepmd.pt.model.task.ener import (
InvarFitting,
Expand Down Expand Up @@ -386,3 +387,145 @@ def test_nlist_lt(self):
to_torch_tensor(nlist),
)
np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1))


class TestEnergyModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
def setUp(self):
TestCaseSingleFrameWithoutNlist.setUp(self)

def test_self_consistency(self):
nf, nloc = self.atype.shape
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
# TODO: dirty hack to avoid data stat!!!
md0 = EnergyModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE)
args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]]
ret0 = md0.forward(*args)
ret1 = md1.forward(*args)
np.testing.assert_allclose(
to_numpy_array(ret0["atom_energy"]),
to_numpy_array(ret1["atom_energy"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["force"]),
to_numpy_array(ret1["force"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["virial"]),
to_numpy_array(ret1["virial"]),
)
ret0 = md0.forward(*args, do_atomic_virial=True)
ret1 = md1.forward(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["atom_virial"]),
to_numpy_array(ret1["atom_virial"]),
)

coord_ext, atype_ext, mapping = extend_coord_with_ghosts(
to_torch_tensor(self.coord),
to_torch_tensor(self.atype),
to_torch_tensor(self.cell),
self.rcut,
)
nlist = build_neighbor_list(
coord_ext,
atype_ext,
self.nloc,
self.rcut,
self.sel,
distinguish_types=md0.distinguish_types(),
)
args = [coord_ext, atype_ext, nlist]
ret2 = md0.forward_lower(*args, do_atomic_virial=True)
# check the consistency between the reduced virial from
# forward and forward_lower
np.testing.assert_allclose(
to_numpy_array(ret0["virial"]),
to_numpy_array(ret2["virial"]),
)


class TestEnergyModelLower(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)

def test_self_consistency(self):
nf, nloc, nnei = self.nlist.shape
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
# TODO: dirty hack to avoid data stat!!!
md0 = EnergyModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE)
args = [
to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
]
ret0 = md0.forward_lower(*args)
ret1 = md1.forward_lower(*args)
np.testing.assert_allclose(
to_numpy_array(ret0["atom_energy"]),
to_numpy_array(ret1["atom_energy"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["extended_force"]),
to_numpy_array(ret1["extended_force"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["virial"]),
to_numpy_array(ret1["virial"]),
)
ret0 = md0.forward_lower(*args, do_atomic_virial=True)
ret1 = md1.forward_lower(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["extended_virial"]),
to_numpy_array(ret1["extended_virial"]),
)

def test_jit(self):
nf, nloc, nnei = self.nlist.shape
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
# TODO: dirty hack to avoid data stat!!!
md0 = EnergyModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
torch.jit.script(md0)

0 comments on commit 13a781f

Please sign in to comment.