Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions deepmd/pd/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def _cal_hg_dynamic(
# n_edge x e_dim
flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
# n_edge x 3 x e_dim
flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape(
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
[-1, 3 * e_dim]
)
# nf x nloc x 3 x e_dim
Expand Down Expand Up @@ -602,7 +602,9 @@ def optim_angle_update_dynamic(
sub_node_update = paddle.matmul(node_ebd, sub_node)
# n_angle * angle_dim
sub_node_update = paddle.index_select(
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), n2a_index, 0
sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]),
n2a_index,
0,
)

# n_edge * angle_dim
Expand Down Expand Up @@ -682,7 +684,7 @@ def optim_edge_update_dynamic(
sub_node_update = paddle.matmul(node_ebd, node)
# n_edge * node/edge_dim
sub_node_update = paddle.index_select(
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]),
sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]),
n2e_index,
0,
)
Expand All @@ -691,7 +693,7 @@ def optim_edge_update_dynamic(
sub_node_ext_update = paddle.matmul(node_ebd_ext, node_ext)
# n_edge * node/edge_dim
sub_node_ext_update = paddle.index_select(
sub_node_ext_update.reshape(nf * nall, sub_node_update.shape[-1]),
sub_node_ext_update.reshape([nf * nall, sub_node_update.shape[-1]]),
n_ext2e_index,
0,
)
Expand All @@ -714,8 +716,8 @@ def forward(
a_nlist: paddle.Tensor, # nf x nloc x a_nnei
a_nlist_mask: paddle.Tensor, # nf x nloc x a_nnei
a_sw: paddle.Tensor, # switch func, nf x nloc x a_nnei
edge_index: paddle.Tensor, # n_edge x 2
angle_index: paddle.Tensor, # n_angle x 3
edge_index: paddle.Tensor, # 2 x n_edge
angle_index: paddle.Tensor, # 3 x n_angle
):
"""
Parameters
Expand All @@ -740,12 +742,12 @@ def forward(
Masks of the neighbor list for angle. real nei 1 otherwise 0
a_sw : nf x nloc x a_nnei
Switch function for angle.
edge_index : Optional for dynamic sel, n_edge x 2
edge_index : Optional for dynamic sel, 2 x n_edge
n2e_index : n_edge
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
n_ext2e_index : n_edge
Broadcast indices from extended node(j) to edge(ij).
angle_index : Optional for dynamic sel, n_angle x 3
angle_index : Optional for dynamic sel, 3 x n_angle
n2a_index : n_angle
Broadcast indices from extended node(j) to angle(ijk).
eij2a_index : n_angle
Expand All @@ -762,25 +764,24 @@ def forward(
a_updated : nf x nloc x a_nnei x a_nnei x a_dim
Updated angle embedding.
"""
nb, nloc, nnei, _ = edge_ebd.shape
nb, nloc, nnei = nlist.shape
nall = node_ebd_ext.shape[1]
node_ebd = node_ebd_ext[:, :nloc, :]
n_edge = int(nlist_mask.sum().item())
if paddle.in_dynamic_mode():
assert [nb, nloc] == node_ebd.shape[:2]
if not self.use_dynamic_sel:
if paddle.in_dynamic_mode():
assert [nb, nloc, nnei, 3] == h2.shape
n_edge = None
else:
if paddle.in_dynamic_mode():
assert [n_edge, 3] == h2.shape
n_edge = h2.shape[0]
del a_nlist # may be used in the future

n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
n2a_index, eij2a_index, eik2a_index = (
angle_index[:, 0],
angle_index[:, 1],
angle_index[:, 2],
angle_index[0],
angle_index[1],
angle_index[2],
)

# nb x nloc x nnei x n_dim [OR] n_edge x n_dim
Expand Down Expand Up @@ -912,7 +913,7 @@ def forward(
n2e_index,
average=False,
num_owner=nb * nloc,
).reshape(nb, nloc, node_edge_update.shape[-1])
).reshape([nb, nloc, node_edge_update.shape[-1]])
/ self.dynamic_e_sel
)
)
Expand Down Expand Up @@ -1058,7 +1059,9 @@ def forward(
if not self.use_dynamic_sel:
# nb x nloc x a_nnei x a_nnei x e_dim
weighted_edge_angle_update = (
a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
a_sw.unsqueeze(-1).unsqueeze(-1)
* a_sw.unsqueeze(-2).unsqueeze(-1)
* edge_angle_update
)
# nb x nloc x a_nnei x e_dim
reduced_edge_angle_update = paddle.sum(
Expand Down
11 changes: 7 additions & 4 deletions deepmd/pd/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ def forward(
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
else:
# avoid jit assertion
edge_index = angle_index = paddle.zeros([1, 3], dtype=nlist.dtype)
edge_index = paddle.zeros([2, 1], dtype=nlist.dtype)
angle_index = paddle.zeros([3, 1], dtype=nlist.dtype)
# get edge and angle embedding
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
if not self.edge_init_use_dist:
Expand All @@ -551,8 +552,10 @@ def forward(
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode
if not parallel_mode:
assert mapping is not None
node_ebd_ext = paddle.take_along_axis(
node_ebd, mapping, 1, broadcast=False
node_ebd_ext = (
paddle.take_along_axis(node_ebd, mapping, 1, broadcast=False)
if not self.use_loc_mapping
else node_ebd
)
else:
raise NotImplementedError("Not implemented")
Expand All @@ -579,7 +582,7 @@ def forward(
edge_ebd,
h2,
sw,
owner=edge_index[:, 0],
owner=edge_index[0],
num_owner=nframes * nloc,
nb=nframes,
nloc=nloc,
Expand Down
44 changes: 25 additions & 19 deletions deepmd/pd/model/network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,24 @@
-------
output: [num_owner, feature_dim]
"""
bin_count = paddle.bincount(owners)
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))

if (num_owner is not None) and (bin_count.shape[0] != num_owner):
difference = num_owner - bin_count.shape[0]
bin_count = paddle.concat(
[bin_count, paddle.ones([difference], dtype=bin_count.dtype)]
)
if num_owner is None or average:
# requires bincount
bin_count = paddle.bincount(owners)
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))

Check warning on line 35 in deepmd/pd/model/network/utils.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/model/network/utils.py#L34-L35

Added lines #L34 - L35 were not covered by tests

if (num_owner is not None) and (bin_count.shape[0] != num_owner):
difference = num_owner - bin_count.shape[0]
bin_count = paddle.concat(

Check warning on line 39 in deepmd/pd/model/network/utils.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/model/network/utils.py#L37-L39

Added lines #L37 - L39 were not covered by tests
[bin_count, paddle.ones([difference], dtype=bin_count.dtype)]
)
else:
bin_count = None

# make sure this operation is done on the same device of data and owners
output = paddle.zeros([bin_count.shape[0], data.shape[1]])
output = output.index_add_(owners, 0, data)
output = paddle.zeros([num_owner, data.shape[1]])
output = output.index_add_(owners, 0, data.astype(output.dtype))
if average:
assert bin_count is not None

Check warning on line 49 in deepmd/pd/model/network/utils.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/model/network/utils.py#L49

Added line #L49 was not covered by tests
output = (output.T / bin_count).T
return output

Expand All @@ -51,6 +56,7 @@
nlist_mask: paddle.Tensor,
a_nlist_mask: paddle.Tensor,
nall: int,
use_loc_mapping: bool = True,
):
"""
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
Expand All @@ -68,12 +74,12 @@

Returns
-------
edge_index : n_edge x 2
edge_index : 2 x n_edge
n2e_index : n_edge
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
n_ext2e_index : n_edge
Broadcast indices from extended node(j) to edge(ij).
angle_index : n_angle x 3
angle_index : 3 x n_angle
n2a_index : n_angle
Broadcast indices from extended node(j) to angle(ijk).
eij2a_index : n_angle
Expand All @@ -100,7 +106,9 @@
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]

# node_ext(j) to edge(ij) index_select
frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall
frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * (
nall if not use_loc_mapping else nloc
)
shifted_nlist = nlist + frame_shift[:, None, None]
# n_edge
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]
Expand Down Expand Up @@ -129,9 +137,7 @@
# n_angle
eik2a_index = edge_index_ik[a_nlist_mask_3d]

return paddle.concat(
[n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1
), paddle.concat(
[n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)],
axis=-1,
)
edge_index_result = paddle.stack([n2e_index, n_ext2e_index], axis=0)
angle_index_result = paddle.stack([n2a_index, eij2a_index, eik2a_index], axis=0)

return edge_index_result, angle_index_result
155 changes: 155 additions & 0 deletions source/tests/pd/model/test_dynamic_sel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools
import unittest

import numpy as np
import paddle

from deepmd.dpmodel.descriptor.dpa3 import (
RepFlowArgs,
)
from deepmd.pd.model.descriptor import (
DescrptDPA3,
)
from deepmd.pd.utils import (
env,
)
from deepmd.pd.utils.env import (
PRECISION_DICT,
)

from ...seed import (
GLOBAL_SEED,
)
from .test_env_mat import (
TestCaseSingleFrameWithNlist,
)

dtype = env.GLOBAL_PD_FLOAT_PRECISION


class TestDescrptDPA3DynamicSel(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self) -> None:
TestCaseSingleFrameWithNlist.setUp(self)

def test_consistency(
self,
) -> None:
rng = np.random.default_rng(100)
nf, nloc, nnei = self.nlist.shape
davg = rng.normal(size=(self.nt, nnei, 4))
dstd = rng.normal(size=(self.nt, nnei, 4))
dstd = 0.1 + np.abs(dstd)

for (
ua,
rus,
ruri,
acr,
nme,
prec,
ect,
optim,
) in itertools.product(
[True, False], # update_angle
["res_residual"], # update_style
["norm", "const"], # update_residual_init
[0, 1], # a_compress_rate
[1, 2], # n_multi_edge_message
["float64"], # precision
[False], # use_econf_tebd
[True, False], # optim_update
):
dtype = PRECISION_DICT[prec]
# rtol, atol = get_tols(prec)
rtol, atol = 1e-5, 1e-7
if prec == "float64":
atol = 1e-8 # marginal GPU test cases...

repflow = RepFlowArgs(
n_dim=20,
e_dim=10,
a_dim=10,
nlayers=3,
e_rcut=self.rcut,
e_rcut_smth=self.rcut_smth,
e_sel=nnei,
a_rcut=self.rcut - 0.1,
a_rcut_smth=self.rcut_smth,
a_sel=nnei,
a_compress_rate=acr,
n_multi_edge_message=nme,
axis_neuron=4,
update_angle=ua,
update_style=rus,
update_residual_init=ruri,
optim_update=optim,
smooth_edge_update=True,
sel_reduce_factor=1.0, # test consistent when sel_reduce_factor == 1.0
)

# dpa3 new impl
dd0 = DescrptDPA3(
self.nt,
repflow=repflow,
# kwargs for descriptor
exclude_types=[],
precision=prec,
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
seed=GLOBAL_SEED,
).to(env.DEVICE)

repflow.use_dynamic_sel = True

# dpa3 new impl
dd1 = DescrptDPA3(
self.nt,
repflow=repflow,
# kwargs for descriptor
exclude_types=[],
precision=prec,
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
seed=GLOBAL_SEED,
).to(env.DEVICE)

dd0.repflows.mean = paddle.to_tensor(davg, dtype=dtype).to(
device=env.DEVICE
)
dd0.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype).to(
device=env.DEVICE
)
rd0, _, _, _, _ = dd0(
paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE),
paddle.to_tensor(self.atype_ext, dtype=paddle.int64).to(
device=env.DEVICE
),
paddle.to_tensor(self.nlist, dtype=paddle.int64).to(device=env.DEVICE),
paddle.to_tensor(self.mapping, dtype=paddle.int64).to(
device=env.DEVICE
),
)
# serialization
dd1.repflows.mean = paddle.to_tensor(davg, dtype=dtype).to(
device=env.DEVICE
)
dd1.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype).to(
device=env.DEVICE
)
rd1, _, _, _, _ = dd1(
paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE),
paddle.to_tensor(self.atype_ext, dtype=paddle.int64).to(
device=env.DEVICE
),
paddle.to_tensor(self.nlist, dtype=paddle.int64).to(device=env.DEVICE),
paddle.to_tensor(self.mapping, dtype=paddle.int64).to(
device=env.DEVICE
),
)
np.testing.assert_allclose(
rd0.numpy(),
rd1.numpy(),
rtol=rtol,
atol=atol,
)