Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add se_r descriptor #3338

Merged
merged 23 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
af51721
feat: add se_r descriptor
anyangml Feb 26, 2024
c0af6fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
8a8107c
fix: UTs, removed old impl
anyangml Feb 26, 2024
fb6340b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
2eb4041
fix: pre-commit
anyangml Feb 26, 2024
6c5224f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
b771db4
fix: update se_r output
anyangml Feb 26, 2024
8a1a86c
chore: refactor
anyangml Feb 26, 2024
50cdfe0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
84d61da
feat: add numpy impl
anyangml Feb 26, 2024
08a6988
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
1f0fd99
fix: UTs
anyangml Feb 27, 2024
c07e02c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
3485608
Merge branch 'devel' into devel
anyangml Feb 27, 2024
e5b074e
gix: match serialization
anyangml Feb 27, 2024
8265242
Merge branch 'devel' into devel
anyangml Feb 27, 2024
3fe3ed4
chore: refactor device
anyangml Feb 27, 2024
9d23e96
chore: refactor device
anyangml Feb 27, 2024
c073241
fix: UTs
anyangml Feb 27, 2024
ce85c0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
72d7a37
fix: dtype
anyangml Feb 27, 2024
54375bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
184e6a0
Merge branch 'devel' into devel
anyangml Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from .env_mat import (
prod_env_mat_se_a,
prod_env_mat_se_r,
)
from .gaussian_lcc import (
DescrptGaussianLcc,
Expand All @@ -27,6 +28,9 @@
DescrptBlockSeA,
DescrptSeA,
)
from .se_r import (

Check warning on line 31 in deepmd/pt/model/descriptor/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/__init__.py#L31

Added line #L31 was not covered by tests
DescrptSeR,
)

__all__ = [
"Descriptor",
Expand All @@ -35,9 +39,11 @@
"DescrptBlockSeA",
"DescrptBlockSeAtten",
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptDPA2",
"prod_env_mat_se_a",
"prod_env_mat_se_r",
"DescrptGaussianLcc",
"DescrptBlockHybrid",
"DescrptBlockRepformers",
Expand Down
53 changes: 53 additions & 0 deletions deepmd/pt/model/descriptor/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@
return env_mat_se_a, diff * mask.unsqueeze(-1), weight


def _make_env_mat_se_r(nlist, coord, rcut: float, ruct_smth: float):

Check warning on line 33 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L33

Added line #L33 was not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"""Make smooth environment matrix."""
bsz, natoms, nnei = nlist.shape
coord = coord.view(bsz, -1, 3)
nall = coord.shape[1]
mask = nlist >= 0

Check warning on line 38 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L35-L38

Added lines #L35 - L38 were not covered by tests
# nlist = nlist * mask ## this impl will contribute nans in Hessian calculation.
nlist = torch.where(mask, nlist, nall - 1)
coord_l = coord[:, :natoms].view(bsz, -1, 1, 3)
index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3)
coord_r = torch.gather(coord, 1, index)
coord_r = coord_r.view(bsz, natoms, nnei, 3)
diff = coord_r - coord_l
length = torch.linalg.norm(diff, dim=-1, keepdim=True)

Check warning on line 46 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L40-L46

Added lines #L40 - L46 were not covered by tests
# for index 0 nloc atom
length = length + ~mask.unsqueeze(-1)
t0 = 1 / length
weight = compute_smooth_weight(length, ruct_smth, rcut)
weight = weight * mask.unsqueeze(-1)
env_mat_se_r = t0 * weight
return env_mat_se_r, diff * mask.unsqueeze(-1), weight

Check warning on line 53 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L48-L53

Added lines #L48 - L53 were not covered by tests


def prod_env_mat_se_a(
extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float
):
Expand Down Expand Up @@ -58,3 +81,33 @@
t_std = stddev[atype] # [n_atom, dim, 4]
env_mat_se_a = (_env_mat_se_a - t_avg) / t_std
return env_mat_se_a, diff, switch


def prod_env_mat_se_r(

Check warning on line 86 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L86

Added line #L86 was not covered by tests
extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float
):
"""Generate smooth environment matrix from atom coordinates and other context.

Args:
- extended_coord: Copied atom coordinates with shape [nframes, nall*3].
- atype: Atom types with shape [nframes, nloc].
- natoms: Batched atom statisics with shape [len(sec)+2].
- box: Batched simulation box with shape [nframes, 9].
- mean: Average value of descriptor per element type with shape [len(sec), nnei, 1].
- stddev: Standard deviation of descriptor per element type with shape [len(sec), nnei, 1].
- deriv_stddev: StdDev of descriptor derivative per element type with shape [len(sec), nnei, 1, 3].
- rcut: Cut-off radius.
- rcut_smth: Smooth hyper-parameter for pair force & energy.

Returns
-------
- env_mat_se_r: Shape is [nframes, natoms[1]*nnei*1].
"""
nframes = extended_coord.shape[0]
Fixed Show fixed Hide fixed
_env_mat_se_r, diff, switch = _make_env_mat_se_r(

Check warning on line 107 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L106-L107

Added lines #L106 - L107 were not covered by tests
nlist, extended_coord, rcut, rcut_smth
) # shape [n_atom, dim, 1]
t_avg = mean[atype] # [n_atom, dim, 1]
t_std = stddev[atype] # [n_atom, dim, 1]
env_mat_se_r = (_env_mat_se_r - t_avg) / t_std
return env_mat_se_r, diff, switch

Check warning on line 113 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L110-L113

Added lines #L110 - L113 were not covered by tests
Loading
Loading