Skip to content

Commit 1a5fe62

Browse files
support paddle backend for dpa3 dynamic
1 parent f8f01cb commit 1a5fe62

File tree

4 files changed

+74
-64
lines changed

4 files changed

+74
-64
lines changed

.pre-commit-config.yaml

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ repos:
6565
- id: clang-format
6666
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$)
6767
# markdown, yaml, CSS, javascript
68-
- repo: https://github.com/pre-commit/mirrors-prettier
69-
rev: v4.0.0-alpha.8
70-
hooks:
71-
- id: prettier
72-
types_or: [markdown, yaml, css]
73-
# workflow files cannot be modified by pre-commit.ci
74-
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
68+
# - repo: https://github.com/pre-commit/mirrors-prettier
69+
# rev: v4.0.0-alpha.8
70+
# hooks:
71+
# - id: prettier
72+
# types_or: [markdown, yaml, css]
73+
# # workflow files cannot be modified by pre-commit.ci
74+
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
7575
# Shell
7676
- repo: https://github.com/scop/pre-commit-shfmt
7777
rev: v3.11.0-1
@@ -83,25 +83,25 @@ repos:
8383
hooks:
8484
- id: cmake-format
8585
#- id: cmake-lint
86-
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
87-
rev: v1.13.0
88-
hooks:
89-
- id: bibtex-tidy
90-
args:
91-
- --curly
92-
- --numeric
93-
- --align=13
94-
- --blank-lines
95-
# disable sort: the order of keys and fields has explict meanings
96-
#- --sort=key
97-
- --duplicates=key,doi,citation,abstract
98-
- --merge=combine
99-
#- --sort-fields
100-
#- --strip-comments
101-
- --trailing-commas
102-
- --encode-urls
103-
- --remove-empty-fields
104-
- --wrap=80
86+
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
87+
# rev: v1.13.0
88+
# hooks:
89+
# - id: bibtex-tidy
90+
# args:
91+
# - --curly
92+
# - --numeric
93+
# - --align=13
94+
# - --blank-lines
95+
# # disable sort: the order of keys and fields has explict meanings
96+
# #- --sort=key
97+
# - --duplicates=key,doi,citation,abstract
98+
# - --merge=combine
99+
# #- --sort-fields
100+
# #- --strip-comments
101+
# - --trailing-commas
102+
# - --encode-urls
103+
# - --remove-empty-fields
104+
# - --wrap=80
105105
# license header
106106
- repo: https://github.com/Lucas-C/pre-commit-hooks
107107
rev: v1.5.5

deepmd/pd/model/descriptor/repflow_layer.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def _cal_hg_dynamic(
372372
# n_edge x e_dim
373373
flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
374374
# n_edge x 3 x e_dim
375-
flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape(
375+
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
376376
[-1, 3 * e_dim]
377377
)
378378
# nf x nloc x 3 x e_dim
@@ -586,7 +586,9 @@ def optim_angle_update_dynamic(
586586
sub_node_update = paddle.matmul(node_ebd, sub_node)
587587
# n_angle * angle_dim
588588
sub_node_update = paddle.index_select(
589-
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), n2a_index, 0
589+
sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]),
590+
n2a_index,
591+
0,
590592
)
591593

592594
# n_edge * angle_dim
@@ -666,7 +668,7 @@ def optim_edge_update_dynamic(
666668
sub_node_update = paddle.matmul(node_ebd, node)
667669
# n_edge * node/edge_dim
668670
sub_node_update = paddle.index_select(
669-
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]),
671+
sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]),
670672
n2e_index,
671673
0,
672674
)
@@ -675,7 +677,7 @@ def optim_edge_update_dynamic(
675677
sub_node_ext_update = paddle.matmul(node_ebd_ext, node_ext)
676678
# n_edge * node/edge_dim
677679
sub_node_ext_update = paddle.index_select(
678-
sub_node_ext_update.reshape(nf * nall, sub_node_update.shape[-1]),
680+
sub_node_ext_update.reshape([nf * nall, sub_node_update.shape[-1]]),
679681
n_ext2e_index,
680682
0,
681683
)
@@ -698,8 +700,8 @@ def forward(
698700
a_nlist: paddle.Tensor, # nf x nloc x a_nnei
699701
a_nlist_mask: paddle.Tensor, # nf x nloc x a_nnei
700702
a_sw: paddle.Tensor, # switch func, nf x nloc x a_nnei
701-
edge_index: paddle.Tensor, # n_edge x 2
702-
angle_index: paddle.Tensor, # n_angle x 3
703+
edge_index: paddle.Tensor, # 2 x n_edge
704+
angle_index: paddle.Tensor, # 3 x n_angle
703705
):
704706
"""
705707
Parameters
@@ -724,12 +726,12 @@ def forward(
724726
Masks of the neighbor list for angle. real nei 1 otherwise 0
725727
a_sw : nf x nloc x a_nnei
726728
Switch function for angle.
727-
edge_index : Optional for dynamic sel, n_edge x 2
729+
edge_index : Optional for dynamic sel, 2 x n_edge
728730
n2e_index : n_edge
729731
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
730732
n_ext2e_index : n_edge
731733
Broadcast indices from extended node(j) to edge(ij).
732-
angle_index : Optional for dynamic sel, n_angle x 3
734+
angle_index : Optional for dynamic sel, 3 x n_angle
733735
n2a_index : n_angle
734736
Broadcast indices from extended node(j) to angle(ijk).
735737
eij2a_index : n_angle
@@ -746,25 +748,24 @@ def forward(
746748
a_updated : nf x nloc x a_nnei x a_nnei x a_dim
747749
Updated angle embedding.
748750
"""
749-
nb, nloc, nnei, _ = edge_ebd.shape
751+
nb, nloc, nnei = nlist.shape
750752
nall = node_ebd_ext.shape[1]
751753
node_ebd = node_ebd_ext[:, :nloc, :]
752-
n_edge = int(nlist_mask.sum().item())
753754
if paddle.in_dynamic_mode():
754755
assert [nb, nloc] == node_ebd.shape[:2]
755756
if not self.use_dynamic_sel:
756757
if paddle.in_dynamic_mode():
757758
assert [nb, nloc, nnei, 3] == h2.shape
759+
n_edge = None
758760
else:
759-
if paddle.in_dynamic_mode():
760-
assert [n_edge, 3] == h2.shape
761+
n_edge = h2.shape[0]
761762
del a_nlist # may be used in the future
762763

763-
n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
764+
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
764765
n2a_index, eij2a_index, eik2a_index = (
765-
angle_index[:, 0],
766-
angle_index[:, 1],
767-
angle_index[:, 2],
766+
angle_index[0],
767+
angle_index[1],
768+
angle_index[2],
768769
)
769770

770771
# nb x nloc x nnei x n_dim [OR] n_edge x n_dim
@@ -896,7 +897,7 @@ def forward(
896897
n2e_index,
897898
average=False,
898899
num_owner=nb * nloc,
899-
).reshape(nb, nloc, node_edge_update.shape[-1])
900+
).reshape([nb, nloc, node_edge_update.shape[-1]])
900901
/ self.dynamic_e_sel
901902
)
902903
)
@@ -1042,7 +1043,9 @@ def forward(
10421043
if not self.use_dynamic_sel:
10431044
# nb x nloc x a_nnei x a_nnei x e_dim
10441045
weighted_edge_angle_update = (
1045-
a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
1046+
a_sw.unsqueeze(-1).unsqueeze(-1)
1047+
* a_sw.unsqueeze(-2).unsqueeze(-1)
1048+
* edge_angle_update
10461049
)
10471050
# nb x nloc x a_nnei x e_dim
10481051
reduced_edge_angle_update = paddle.sum(

deepmd/pd/model/descriptor/repflows.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,8 @@ def forward(
515515
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
516516
else:
517517
# avoid jit assertion
518-
edge_index = angle_index = paddle.zeros([1, 3], dtype=nlist.dtype)
518+
edge_index = paddle.zeros([2, 1], dtype=nlist.dtype)
519+
angle_index = paddle.zeros([3, 1], dtype=nlist.dtype)
519520
# get edge and angle embedding
520521
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
521522
if not self.edge_init_use_dist:
@@ -566,7 +567,7 @@ def forward(
566567
edge_ebd,
567568
h2,
568569
sw,
569-
owner=edge_index[:, 0],
570+
owner=edge_index[0],
570571
num_owner=nframes * nloc,
571572
nb=nframes,
572573
nloc=nloc,

deepmd/pd/model/network/utils.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,24 @@ def aggregate(
2929
-------
3030
output: [num_owner, feature_dim]
3131
"""
32-
bin_count = paddle.bincount(owners)
33-
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
34-
35-
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
36-
difference = num_owner - bin_count.shape[0]
37-
bin_count = paddle.concat(
38-
[bin_count, paddle.ones([difference], dtype=bin_count.dtype)]
39-
)
32+
if num_owner is None or average:
33+
# requires bincount
34+
bin_count = paddle.bincount(owners)
35+
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
36+
37+
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
38+
difference = num_owner - bin_count.shape[0]
39+
bin_count = paddle.concat(
40+
[bin_count, paddle.ones([difference], dtype=bin_count.dtype)]
41+
)
42+
else:
43+
bin_count = None
4044

4145
# make sure this operation is done on the same device of data and owners
42-
output = paddle.zeros([bin_count.shape[0], data.shape[1]])
46+
output = paddle.zeros([num_owner, data.shape[1]])
4347
output = output.index_add_(owners, 0, data)
4448
if average:
49+
assert bin_count is not None
4550
output = (output.T / bin_count).T
4651
return output
4752

@@ -51,6 +56,7 @@ def get_graph_index(
5156
nlist_mask: paddle.Tensor,
5257
a_nlist_mask: paddle.Tensor,
5358
nall: int,
59+
use_loc_mapping: bool = True,
5460
):
5561
"""
5662
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
@@ -68,12 +74,12 @@ def get_graph_index(
6874
6975
Returns
7076
-------
71-
edge_index : n_edge x 2
77+
edge_index : 2 x n_edge
7278
n2e_index : n_edge
7379
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
7480
n_ext2e_index : n_edge
7581
Broadcast indices from extended node(j) to edge(ij).
76-
angle_index : n_angle x 3
82+
angle_index : 3 x n_angle
7783
n2a_index : n_angle
7884
Broadcast indices from extended node(j) to angle(ijk).
7985
eij2a_index : n_angle
@@ -100,7 +106,9 @@ def get_graph_index(
100106
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]
101107

102108
# node_ext(j) to edge(ij) index_select
103-
frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall
109+
frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * (
110+
nall if not use_loc_mapping else nloc
111+
)
104112
shifted_nlist = nlist + frame_shift[:, None, None]
105113
# n_edge
106114
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]
@@ -129,9 +137,7 @@ def get_graph_index(
129137
# n_angle
130138
eik2a_index = edge_index_ik[a_nlist_mask_3d]
131139

132-
return paddle.concat(
133-
[n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1
134-
), paddle.concat(
135-
[n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)],
136-
axis=-1,
137-
)
140+
edge_index_result = paddle.stack([n2e_index, n_ext2e_index], axis=0)
141+
angle_index_result = paddle.stack([n2a_index, eij2a_index, eik2a_index], axis=0)
142+
143+
return edge_index_result, angle_index_result

0 commit comments

Comments
 (0)