Commit 1eefc8e
pd: support dpa3 dynamic shape for pd backend (#4828)
Support running `input_torch_dynamic.json` with the Paddle backend (including CINN).

TODO list:
- [x] PaddlePaddle/Paddle#73601
- [x] PaddlePaddle/Paddle#73622
- [x] PaddlePaddle/Paddle#73737
- [x] PaddlePaddle/Paddle#73747
- [x] PaddlePaddle/Paddle#73809
- [x] PaddlePaddle/Paddle#73761

## Summary by CodeRabbit

* **Bug Fixes**
  * Resolved issues with tensor shape and indexing consistency, preventing assertion errors during model execution.
  * Improved handling of default tensor initialization to avoid JIT assertion issues.
* **Refactor**
  * Standardized tensor dimension handling and broadcasting for improved clarity and maintainability.
  * Enhanced code readability with clearer indexing conventions and formatting.
  * Updated aggregation logic for safer and more efficient tensor operations.
* **New Features**
  * Added an option to control graph index mapping behavior for greater flexibility in advanced use cases.
* **Tests**
  * Introduced comprehensive tests validating descriptor model consistency with dynamic selection enabled.
1 parent c151e04 commit 1eefc8e
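The common thread in the hunks below is keeping sizes such as `n_edge` tied to tensor shapes rather than Python ints, so Paddle's JIT/CINN can treat them as dynamic dimensions. A minimal, self-contained sketch of that pattern (not part of the commit; sizes and names are made up):

```python
import paddle

# hypothetical sizes, for illustration only
nf, nloc, nnei = 2, 5, 4
nlist_mask = paddle.rand([nf, nloc, nnei]) > 0.3  # True for real neighbors
sw = paddle.rand([nf, nloc, nnei])                # padded per-neighbor values

# flatten the padded per-neighbor values into a dense per-edge tensor
flat_sw = sw[nlist_mask]                          # shape: [n_edge]

# read n_edge from the tensor itself instead of nlist_mask.sum().item(),
# so the size can stay a symbolic dimension under JIT/CINN
n_edge = flat_sw.shape[0]
print(n_edge, flat_sw.shape)
```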

File tree

4 files changed: +208 -41 lines changed

deepmd/pd/model/descriptor/repflow_layer.py

Lines changed: 21 additions & 18 deletions
@@ -388,7 +388,7 @@ def _cal_hg_dynamic(
         # n_edge x e_dim
         flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
         # n_edge x 3 x e_dim
-        flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape(
+        flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
             [-1, 3 * e_dim]
         )
         # nf x nloc x 3 x e_dim
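The replacement only swaps `None`-indexing for explicit `unsqueeze` calls; both forms produce the same `n_edge x 3 x e_dim` outer product. A small check with made-up sizes (not part of the commit), using only core Paddle ops:

```python
import paddle

n_edge, e_dim = 7, 5
flat_h2 = paddle.rand([n_edge, 3])            # n_edge x 3
flat_edge_ebd = paddle.rand([n_edge, e_dim])  # n_edge x e_dim

old = flat_h2[..., None] * flat_edge_ebd[:, None, :]       # n_edge x 3 x e_dim
new = flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)  # n_edge x 3 x e_dim

# both broadcasts give the same n_edge x 3 x e_dim tensor
assert old.shape == new.shape == [n_edge, 3, e_dim]
assert bool(paddle.allclose(old, new))
```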
@@ -602,7 +602,9 @@ def optim_angle_update_dynamic(
         sub_node_update = paddle.matmul(node_ebd, sub_node)
         # n_angle * angle_dim
         sub_node_update = paddle.index_select(
-            sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), n2a_index, 0
+            sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]),
+            n2a_index,
+            0,
         )

         # n_edge * angle_dim
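The same `reshape` change recurs throughout this file: the target shape is passed as a single list, matching the `Tensor.reshape(shape)` form used elsewhere in the Paddle backend. A trivial sketch (values are made up):

```python
import paddle

nf, nloc, dim = 2, 3, 4
sub_node_update = paddle.rand([nf, nloc, dim])

# target shape given as one list argument, the form this commit standardizes on
flat = sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]])
print(flat.shape)  # [6, 4]
```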
@@ -682,7 +684,7 @@ def optim_edge_update_dynamic(
         sub_node_update = paddle.matmul(node_ebd, node)
         # n_edge * node/edge_dim
         sub_node_update = paddle.index_select(
-            sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]),
+            sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]),
             n2e_index,
             0,
         )
@@ -691,7 +693,7 @@ def optim_edge_update_dynamic(
         sub_node_ext_update = paddle.matmul(node_ebd_ext, node_ext)
         # n_edge * node/edge_dim
         sub_node_ext_update = paddle.index_select(
-            sub_node_ext_update.reshape(nf * nall, sub_node_update.shape[-1]),
+            sub_node_ext_update.reshape([nf * nall, sub_node_update.shape[-1]]),
             n_ext2e_index,
             0,
         )
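In both `optim_*_update_dynamic` hunks, per-node features are broadcast to per-edge rows by flattening to `(nf * nall) x dim` and gathering with `paddle.index_select`. A minimal sketch with fabricated indices (not the real `n_ext2e_index`):

```python
import paddle

nf, nall, dim = 1, 4, 3
node_feat = paddle.rand([nf, nall, dim])

# hypothetical flat indices into the nf * nall axis, one per edge
n_ext2e_index = paddle.to_tensor([0, 0, 1, 3, 3], dtype="int64")

edge_feat = paddle.index_select(
    node_feat.reshape([nf * nall, node_feat.shape[-1]]),  # (nf * nall) x dim
    n_ext2e_index,
    0,
)
print(edge_feat.shape)  # [5, 3]
```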
@@ -714,8 +716,8 @@ def forward(
         a_nlist: paddle.Tensor,  # nf x nloc x a_nnei
         a_nlist_mask: paddle.Tensor,  # nf x nloc x a_nnei
         a_sw: paddle.Tensor,  # switch func, nf x nloc x a_nnei
-        edge_index: paddle.Tensor,  # n_edge x 2
-        angle_index: paddle.Tensor,  # n_angle x 3
+        edge_index: paddle.Tensor,  # 2 x n_edge
+        angle_index: paddle.Tensor,  # 3 x n_angle
     ):
         """
         Parameters
@@ -740,12 +742,12 @@ def forward(
             Masks of the neighbor list for angle. real nei 1 otherwise 0
         a_sw : nf x nloc x a_nnei
             Switch function for angle.
-        edge_index : Optional for dynamic sel, n_edge x 2
+        edge_index : Optional for dynamic sel, 2 x n_edge
             n2e_index : n_edge
                 Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
             n_ext2e_index : n_edge
                 Broadcast indices from extended node(j) to edge(ij).
-        angle_index : Optional for dynamic sel, n_angle x 3
+        angle_index : Optional for dynamic sel, 3 x n_angle
            n2a_index : n_angle
                Broadcast indices from extended node(j) to angle(ijk).
            eij2a_index : n_angle
@@ -762,25 +764,24 @@ def forward(
         a_updated : nf x nloc x a_nnei x a_nnei x a_dim
             Updated angle embedding.
         """
-        nb, nloc, nnei, _ = edge_ebd.shape
+        nb, nloc, nnei = nlist.shape
         nall = node_ebd_ext.shape[1]
         node_ebd = node_ebd_ext[:, :nloc, :]
-        n_edge = int(nlist_mask.sum().item())
         if paddle.in_dynamic_mode():
             assert [nb, nloc] == node_ebd.shape[:2]
         if not self.use_dynamic_sel:
             if paddle.in_dynamic_mode():
                 assert [nb, nloc, nnei, 3] == h2.shape
+            n_edge = None
         else:
-            if paddle.in_dynamic_mode():
-                assert [n_edge, 3] == h2.shape
+            n_edge = h2.shape[0]
         del a_nlist  # may be used in the future

-        n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
+        n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
         n2a_index, eij2a_index, eik2a_index = (
-            angle_index[:, 0],
-            angle_index[:, 1],
-            angle_index[:, 2],
+            angle_index[0],
+            angle_index[1],
+            angle_index[2],
         )

         # nb x nloc x nnei x n_dim [OR] n_edge x n_dim
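With the new `2 x n_edge` / `3 x n_angle` layout, the forward pass unpacks rows instead of columns, and `n_edge` now comes from `h2.shape[0]` rather than a Python int computed from the mask. A toy illustration with fabricated indices (not taken from the commit):

```python
import paddle

# fabricated index tensors in the new row-major layout
edge_index = paddle.to_tensor([[0, 0, 1, 2],      # n2e_index
                               [1, 2, 2, 3]])     # n_ext2e_index  -> 2 x n_edge
angle_index = paddle.to_tensor([[0, 0, 1],        # n2a_index
                                [0, 1, 2],        # eij2a_index
                                [1, 0, 3]])       # eik2a_index    -> 3 x n_angle

n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
n2a_index, eij2a_index, eik2a_index = angle_index[0], angle_index[1], angle_index[2]

n_edge = edge_index.shape[1]  # stays tied to the tensor shape, JIT-friendly
print(n_edge, n2e_index.shape, n2a_index.shape)
```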
@@ -912,7 +913,7 @@ def forward(
                     n2e_index,
                     average=False,
                     num_owner=nb * nloc,
-                ).reshape(nb, nloc, node_edge_update.shape[-1])
+                ).reshape([nb, nloc, node_edge_update.shape[-1]])
                 / self.dynamic_e_sel
             )
         )
@@ -1058,7 +1059,9 @@ def forward(
         if not self.use_dynamic_sel:
             # nb x nloc x a_nnei x a_nnei x e_dim
             weighted_edge_angle_update = (
-                a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
+                a_sw.unsqueeze(-1).unsqueeze(-1)
+                * a_sw.unsqueeze(-2).unsqueeze(-1)
+                * edge_angle_update
             )
             # nb x nloc x a_nnei x e_dim
             reduced_edge_angle_update = paddle.sum(

deepmd/pd/model/descriptor/repflows.py

Lines changed: 7 additions & 4 deletions
@@ -528,7 +528,8 @@ def forward(
            a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
        else:
            # avoid jit assertion
-            edge_index = angle_index = paddle.zeros([1, 3], dtype=nlist.dtype)
+            edge_index = paddle.zeros([2, 1], dtype=nlist.dtype)
+            angle_index = paddle.zeros([3, 1], dtype=nlist.dtype)
        # get edge and angle embedding
        # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
        if not self.edge_init_use_dist:
@@ -551,8 +552,10 @@ def forward(
            # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode
            if not parallel_mode:
                assert mapping is not None
-                node_ebd_ext = paddle.take_along_axis(
-                    node_ebd, mapping, 1, broadcast=False
+                node_ebd_ext = (
+                    paddle.take_along_axis(node_ebd, mapping, 1, broadcast=False)
+                    if not self.use_loc_mapping
+                    else node_ebd
                )
            else:
                raise NotImplementedError("Not implemented")
@@ -579,7 +582,7 @@ def forward(
                edge_ebd,
                h2,
                sw,
-                owner=edge_index[:, 0],
+                owner=edge_index[0],
                num_owner=nframes * nloc,
                nb=nframes,
                nloc=nloc,

deepmd/pd/model/network/utils.py

Lines changed: 25 additions & 19 deletions
@@ -29,19 +29,24 @@ def aggregate(
     -------
     output: [num_owner, feature_dim]
     """
-    bin_count = paddle.bincount(owners)
-    bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
-
-    if (num_owner is not None) and (bin_count.shape[0] != num_owner):
-        difference = num_owner - bin_count.shape[0]
-        bin_count = paddle.concat(
-            [bin_count, paddle.ones([difference], dtype=bin_count.dtype)]
-        )
+    if num_owner is None or average:
+        # requires bincount
+        bin_count = paddle.bincount(owners)
+        bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
+
+        if (num_owner is not None) and (bin_count.shape[0] != num_owner):
+            difference = num_owner - bin_count.shape[0]
+            bin_count = paddle.concat(
+                [bin_count, paddle.ones([difference], dtype=bin_count.dtype)]
+            )
+    else:
+        bin_count = None

     # make sure this operation is done on the same device of data and owners
-    output = paddle.zeros([bin_count.shape[0], data.shape[1]])
-    output = output.index_add_(owners, 0, data)
+    output = paddle.zeros([num_owner, data.shape[1]])
+    output = output.index_add_(owners, 0, data.astype(output.dtype))
     if average:
+        assert bin_count is not None
         output = (output.T / bin_count).T
     return output
@@ -51,6 +56,7 @@ def get_graph_index(
    nlist_mask: paddle.Tensor,
    a_nlist_mask: paddle.Tensor,
    nall: int,
+    use_loc_mapping: bool = True,
 ):
    """
    Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
@@ -68,12 +74,12 @@ def get_graph_index(

    Returns
    -------
-    edge_index : n_edge x 2
+    edge_index : 2 x n_edge
        n2e_index : n_edge
            Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
        n_ext2e_index : n_edge
            Broadcast indices from extended node(j) to edge(ij).
-    angle_index : n_angle x 3
+    angle_index : 3 x n_angle
        n2a_index : n_angle
            Broadcast indices from extended node(j) to angle(ijk).
        eij2a_index : n_angle
@@ -100,7 +106,9 @@ def get_graph_index(
    n2e_index = n2e_index[nlist_mask]  # graph node index, atom_graph[:, 0]

    # node_ext(j) to edge(ij) index_select
-    frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall
+    frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * (
+        nall if not use_loc_mapping else nloc
+    )
    shifted_nlist = nlist + frame_shift[:, None, None]
    # n_edge
    n_ext2e_index = shifted_nlist[nlist_mask]  # graph neighbor index, atom_graph[:, 1]
@@ -129,9 +137,7 @@ def get_graph_index(
    # n_angle
    eik2a_index = edge_index_ik[a_nlist_mask_3d]

-    return paddle.concat(
-        [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1
-    ), paddle.concat(
-        [n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)],
-        axis=-1,
-    )
+    edge_index_result = paddle.stack([n2e_index, n_ext2e_index], axis=0)
+    angle_index_result = paddle.stack([n2a_index, eij2a_index, eik2a_index], axis=0)
+
+    return edge_index_result, angle_index_result
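`paddle.stack(..., axis=0)` packs the index vectors as rows, so `get_graph_index` now returns `2 x n_edge` / `3 x n_angle` tensors, i.e. the transpose of what the old concat-of-unsqueezed version produced. A small check with fabricated index vectors (not part of the commit):

```python
import paddle

n2e_index = paddle.to_tensor([0, 0, 1, 2], dtype="int64")
n_ext2e_index = paddle.to_tensor([1, 2, 2, 3], dtype="int64")

old_layout = paddle.concat(
    [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1
)                                                              # n_edge x 2
new_layout = paddle.stack([n2e_index, n_ext2e_index], axis=0)  # 2 x n_edge

assert old_layout.shape == [4, 2]
assert new_layout.shape == [2, 4]
assert bool(paddle.all(new_layout == old_layout.T))
```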
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import itertools
+import unittest
+
+import numpy as np
+import paddle
+
+from deepmd.dpmodel.descriptor.dpa3 import (
+    RepFlowArgs,
+)
+from deepmd.pd.model.descriptor import (
+    DescrptDPA3,
+)
+from deepmd.pd.utils import (
+    env,
+)
+from deepmd.pd.utils.env import (
+    PRECISION_DICT,
+)
+
+from ...seed import (
+    GLOBAL_SEED,
+)
+from .test_env_mat import (
+    TestCaseSingleFrameWithNlist,
+)
+
+dtype = env.GLOBAL_PD_FLOAT_PRECISION
+
+
+class TestDescrptDPA3DynamicSel(unittest.TestCase, TestCaseSingleFrameWithNlist):
+    def setUp(self) -> None:
+        TestCaseSingleFrameWithNlist.setUp(self)
+
+    def test_consistency(
+        self,
+    ) -> None:
+        rng = np.random.default_rng(100)
+        nf, nloc, nnei = self.nlist.shape
+        davg = rng.normal(size=(self.nt, nnei, 4))
+        dstd = rng.normal(size=(self.nt, nnei, 4))
+        dstd = 0.1 + np.abs(dstd)
+
+        for (
+            ua,
+            rus,
+            ruri,
+            acr,
+            nme,
+            prec,
+            ect,
+            optim,
+        ) in itertools.product(
+            [True, False],  # update_angle
+            ["res_residual"],  # update_style
+            ["norm", "const"],  # update_residual_init
+            [0, 1],  # a_compress_rate
+            [1, 2],  # n_multi_edge_message
+            ["float64"],  # precision
+            [False],  # use_econf_tebd
+            [True, False],  # optim_update
+        ):
+            dtype = PRECISION_DICT[prec]
+            # rtol, atol = get_tols(prec)
+            rtol, atol = 1e-5, 1e-7
+            if prec == "float64":
+                atol = 1e-8  # marginal GPU test cases...
+
+            repflow = RepFlowArgs(
+                n_dim=20,
+                e_dim=10,
+                a_dim=10,
+                nlayers=3,
+                e_rcut=self.rcut,
+                e_rcut_smth=self.rcut_smth,
+                e_sel=nnei,
+                a_rcut=self.rcut - 0.1,
+                a_rcut_smth=self.rcut_smth,
+                a_sel=nnei,
+                a_compress_rate=acr,
+                n_multi_edge_message=nme,
+                axis_neuron=4,
+                update_angle=ua,
+                update_style=rus,
+                update_residual_init=ruri,
+                optim_update=optim,
+                smooth_edge_update=True,
+                sel_reduce_factor=1.0,  # test consistent when sel_reduce_factor == 1.0
+            )
+
+            # dpa3 new impl
+            dd0 = DescrptDPA3(
+                self.nt,
+                repflow=repflow,
+                # kwargs for descriptor
+                exclude_types=[],
+                precision=prec,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                seed=GLOBAL_SEED,
+            ).to(env.DEVICE)
+
+            repflow.use_dynamic_sel = True
+
+            # dpa3 new impl
+            dd1 = DescrptDPA3(
+                self.nt,
+                repflow=repflow,
+                # kwargs for descriptor
+                exclude_types=[],
+                precision=prec,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                seed=GLOBAL_SEED,
+            ).to(env.DEVICE)
+
+            dd0.repflows.mean = paddle.to_tensor(davg, dtype=dtype).to(
+                device=env.DEVICE
+            )
+            dd0.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype).to(
+                device=env.DEVICE
+            )
+            rd0, _, _, _, _ = dd0(
+                paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE),
+                paddle.to_tensor(self.atype_ext, dtype=paddle.int64).to(
+                    device=env.DEVICE
+                ),
+                paddle.to_tensor(self.nlist, dtype=paddle.int64).to(device=env.DEVICE),
+                paddle.to_tensor(self.mapping, dtype=paddle.int64).to(
+                    device=env.DEVICE
+                ),
+            )
+            # serialization
+            dd1.repflows.mean = paddle.to_tensor(davg, dtype=dtype).to(
+                device=env.DEVICE
+            )
+            dd1.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype).to(
+                device=env.DEVICE
+            )
+            rd1, _, _, _, _ = dd1(
+                paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE),
+                paddle.to_tensor(self.atype_ext, dtype=paddle.int64).to(
+                    device=env.DEVICE
+                ),
+                paddle.to_tensor(self.nlist, dtype=paddle.int64).to(device=env.DEVICE),
+                paddle.to_tensor(self.mapping, dtype=paddle.int64).to(
+                    device=env.DEVICE
+                ),
+            )
+            np.testing.assert_allclose(
+                rd0.numpy(),
+                rd1.numpy(),
+                rtol=rtol,
+                atol=atol,
+            )
