Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: directional nlist #4052

Merged
merged 7 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 123 additions & 5 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
).view(batch_size, nall * 3)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
# nloc x 3
coord0 = coord1[:, : nloc * 3]
# nloc x nall x 3
Expand All @@ -126,8 +125,26 @@
# nloc x (nall-1)
rr = rr[:, :, 1:]
nlist = nlist[:, :, 1:]

return _trim_mask_distinguish_nlist(

Check warning on line 129 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L129

Added line #L129 was not covered by tests
is_vir, atype, rr, nlist, rcut, sel, distinguish_types
)


def _trim_mask_distinguish_nlist(

Check warning on line 134 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L134

Added line #L134 was not covered by tests
is_vir_cntl: torch.Tensor,
atype_neig: torch.Tensor,
rr: torch.Tensor,
nlist: torch.Tensor,
rcut: float,
sel: List[int],
distinguish_types: bool,
) -> torch.Tensor:
"""Trim the size of nlist, mask if any central atom is virtual, distinguish types if necessary."""
nsel = sum(sel)

Check warning on line 144 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L144

Added line #L144 was not covered by tests
# nloc x nsel
nnei = rr.shape[2]
batch_size, nloc, nnei = rr.shape
assert (batch_size, nloc) == is_vir_cntl.shape

Check warning on line 147 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L146-L147

Added lines #L146 - L147 were not covered by tests
if nsel <= nnei:
rr = rr[:, :, :nsel]
nlist = nlist[:, :, :nsel]
Expand All @@ -147,15 +164,116 @@
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = torch.where(
torch.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist
torch.logical_or((rr > rcut), is_vir_cntl[:, :nloc, None]), -1, nlist
)

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
return nlist_distinguish_types(nlist, atype_neig, sel)

Check warning on line 170 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L170

Added line #L170 was not covered by tests
else:
return nlist


def build_directional_neighbor_list(

Check warning on line 175 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L175

Added line #L175 was not covered by tests
coord_cntl: torch.Tensor,
atype_cntl: torch.Tensor,
coord_neig: torch.Tensor,
atype_neig: torch.Tensor,
rcut: float,
sel: Union[int, List[int]],
distinguish_types: bool = True,
) -> torch.Tensor:
"""Build directional neighbor list.

With each central atom, all the neighbor atoms in the cut-off radius will
be recorded in the neighbor list. The maximum neighbors is nsel. If the real
number of neighbors is larger than nsel, the neighbors will be sorted with the
distance and the first nsel neighbors are kept.

Important: the central and neighboring atoms are assume to be different atoms.

Parameters
----------
coord_central : torch.Tensor
coordinates of central atoms. assumed to be local atoms.
shape [batch_size, nloc_central x 3]
atype_central : torch.Tensor
atomic types of central atoms. shape [batch_size, nloc_central]
if type < 0 the atom is treated as virtual atoms.
coord_neighbor : torch.Tensor
extended coordinates of neighbors atoms. shape [batch_size, nall_neighbor x 3]
atype_central : torch.Tensor
extended atomic types of neighbors atoms. shape [batch_size, nall_neighbor]
if type < 0 the atom is treated as virtual atoms.
rcut : float
cut-off radius
sel : int or List[int]
maximal number of neighbors (of each type).
if distinguish_types==True, nsel should be list and
the length of nsel should be equal to number of
types.
distinguish_types : bool
distinguish different types.

Returns
-------
neighbor_list : torch.Tensor
Neighbor list of shape [batch_size, nloc_central, nsel], the neighbors
are stored in an ascending order. If the number of neighbors is less than nsel,
the positions are masked with -1. The neighbor list of an atom looks like
|------ nsel ------|
xx xx xx xx -1 -1 -1
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.
"""
batch_size = coord_cntl.shape[0]
coord_cntl = coord_cntl.view(batch_size, -1)
nloc_cntl = coord_cntl.shape[1] // 3
coord_neig = coord_neig.view(batch_size, -1)
nall_neig = coord_neig.shape[1] // 3

Check warning on line 233 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L229-L233

Added lines #L229 - L233 were not covered by tests
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if coord_neig.numel() > 0:
xmax = torch.max(coord_cntl) + 2.0 * rcut

Check warning on line 237 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L236-L237

Added lines #L236 - L237 were not covered by tests
else:
xmax = (

Check warning on line 239 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L239

Added line #L239 was not covered by tests
torch.zeros(1, dtype=coord_neig.dtype, device=coord_neig.device)
+ 2.0 * rcut
)
# nf x nloc
is_vir_cntl = atype_cntl < 0

Check warning on line 244 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L244

Added line #L244 was not covered by tests
# nf x nall
is_vir_neig = atype_neig < 0

Check warning on line 246 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L246

Added line #L246 was not covered by tests
# nf x nloc x 3
coord_cntl = coord_cntl.view(batch_size, nloc_cntl, 3)

Check warning on line 248 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L248

Added line #L248 was not covered by tests
# nf x nall x 3
coord_neig = torch.where(

Check warning on line 250 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L250

Added line #L250 was not covered by tests
is_vir_neig[:, :, None], xmax, coord_neig.view(batch_size, nall_neig, 3)
).view(batch_size, nall_neig, 3)
# nsel
if isinstance(sel, int):
sel = [sel]

Check warning on line 255 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L254-L255

Added lines #L254 - L255 were not covered by tests
# nloc x nall x 3
diff = coord_neig[:, None, :, :] - coord_cntl[:, :, None, :]
assert list(diff.shape) == [batch_size, nloc_cntl, nall_neig, 3]

Check warning on line 258 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L257-L258

Added lines #L257 - L258 were not covered by tests
# nloc x nall
rr = torch.linalg.norm(diff, dim=-1)
rr, nlist = torch.sort(rr, dim=-1)

Check warning on line 261 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L260-L261

Added lines #L260 - L261 were not covered by tests

# We assume that the central and neighbor atoms are diffferent,
# thus we do not need to exclude self-neighbors.
# # if central atom has two zero distances, sorting sometimes can not exclude itself
# rr -= torch.eye(nloc_cntl, nall_neig, dtype=rr.dtype, device=rr.device).unsqueeze(0)
# rr, nlist = torch.sort(rr, dim=-1)
# # nloc x (nall-1)
# rr = rr[:, :, 1:]
# nlist = nlist[:, :, 1:]

return _trim_mask_distinguish_nlist(

Check warning on line 272 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L272

Added line #L272 was not covered by tests
is_vir_cntl, atype_neig, rr, nlist, rcut, sel, distinguish_types
)


def nlist_distinguish_types(
nlist: torch.Tensor,
atype: torch.Tensor,
Expand Down
70 changes: 69 additions & 1 deletion source/tests/pt/model/test_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
env,
)
from deepmd.pt.utils.nlist import (
build_directional_neighbor_list,
build_multiple_neighbor_list,
build_neighbor_list,
extend_coord_with_ghosts,
Expand Down Expand Up @@ -62,6 +63,7 @@ def test_build_notype(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
# test normal sel
nlist = build_neighbor_list(
ecoord,
eatype,
Expand All @@ -70,14 +72,29 @@ def test_build_notype(self):
sum(self.nsel),
distinguish_types=False,
)
torch.testing.assert_close(nlist[0], nlist[1])
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc[nlist_mask] = -1
torch.testing.assert_close(
torch.sort(nlist_loc, dim=-1)[0],
torch.sort(self.ref_nlist, dim=-1)[0],
)
# test a very large sel
nlist = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 300, # +300, real nnei==224
distinguish_types=False,
)
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc[nlist_mask] = -1
torch.testing.assert_close(
torch.sort(nlist_loc, descending=True, dim=-1)[0][:, : sum(self.nsel)],
torch.sort(self.ref_nlist, descending=True, dim=-1)[0],
)

def test_build_type(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
Expand Down Expand Up @@ -218,3 +235,54 @@ def test_extend_coord(self):
rtol=self.prec,
atol=self.prec,
)

def test_build_directional_nlist(self):
"""Directional nlist is tested against the standard nlist implementation."""
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
for distinguish_types, mysel in zip([True, False], [sum(self.nsel), 300]):
# full neighbor list
nlist_full = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=distinguish_types,
)
# central as part of the system
nlist = build_directional_neighbor_list(
ecoord[:, 3:6],
eatype[:, 1:2],
torch.concat(
[
ecoord[:, 0:3],
torch.zeros(
[self.nf, 3], dtype=dtype, device=env.DEVICE
), # placeholder
ecoord[:, 6:],
],
dim=1,
),
torch.concat(
[
eatype[:, 0:1],
-1
* torch.ones(
[self.nf, 1], dtype=int, device=env.DEVICE
), # placeholder
eatype[:, 2:],
],
dim=1,
),
self.rcut,
mysel,
distinguish_types=distinguish_types,
)
torch.testing.assert_close(nlist[0], nlist[1])
torch.testing.assert_close(nlist[0], nlist[2])
torch.testing.assert_close(
torch.sort(nlist[0], descending=True, dim=-1)[0][:, : sum(self.nsel)],
torch.sort(nlist_full[0][1:2], descending=True, dim=-1)[0],
)
Loading