From dd70e990022584810f7fea540e48ce02114fa6f0 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 30 Jun 2025 15:12:45 +0800 Subject: [PATCH 1/2] support paddle backend for dpa3 dynamic --- deepmd/pd/model/descriptor/repflow_layer.py | 39 ++++++++++--------- deepmd/pd/model/descriptor/repflows.py | 5 ++- deepmd/pd/model/network/utils.py | 42 ++++++++++++--------- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/deepmd/pd/model/descriptor/repflow_layer.py b/deepmd/pd/model/descriptor/repflow_layer.py index 7de88d8bd9..f2822b62a0 100644 --- a/deepmd/pd/model/descriptor/repflow_layer.py +++ b/deepmd/pd/model/descriptor/repflow_layer.py @@ -372,7 +372,7 @@ def _cal_hg_dynamic( # n_edge x e_dim flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1) # n_edge x 3 x e_dim - flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape( + flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape( [-1, 3 * e_dim] ) # nf x nloc x 3 x e_dim @@ -586,7 +586,9 @@ def optim_angle_update_dynamic( sub_node_update = paddle.matmul(node_ebd, sub_node) # n_angle * angle_dim sub_node_update = paddle.index_select( - sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), n2a_index, 0 + sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]), + n2a_index, + 0, ) # n_edge * angle_dim @@ -666,7 +668,7 @@ def optim_edge_update_dynamic( sub_node_update = paddle.matmul(node_ebd, node) # n_edge * node/edge_dim sub_node_update = paddle.index_select( - sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), + sub_node_update.reshape([nf * nloc, sub_node_update.shape[-1]]), n2e_index, 0, ) @@ -675,7 +677,7 @@ def optim_edge_update_dynamic( sub_node_ext_update = paddle.matmul(node_ebd_ext, node_ext) # n_edge * node/edge_dim sub_node_ext_update = paddle.index_select( - sub_node_ext_update.reshape(nf * nall, sub_node_update.shape[-1]), + sub_node_ext_update.reshape([nf * nall, sub_node_update.shape[-1]]), n_ext2e_index, 0, ) @@ -698,8 +700,8 @@ def forward( a_nlist: paddle.Tensor, # nf x nloc x a_nnei a_nlist_mask: paddle.Tensor, # nf x nloc x a_nnei a_sw: paddle.Tensor, # switch func, nf x nloc x a_nnei - edge_index: paddle.Tensor, # n_edge x 2 - angle_index: paddle.Tensor, # n_angle x 3 + edge_index: paddle.Tensor, # 2 x n_edge + angle_index: paddle.Tensor, # 3 x n_angle ): """ Parameters @@ -724,12 +726,12 @@ def forward( Masks of the neighbor list for angle. real nei 1 otherwise 0 a_sw : nf x nloc x a_nnei Switch function for angle. - edge_index : Optional for dynamic sel, n_edge x 2 + edge_index : Optional for dynamic sel, 2 x n_edge n2e_index : n_edge Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). n_ext2e_index : n_edge Broadcast indices from extended node(j) to edge(ij). - angle_index : Optional for dynamic sel, n_angle x 3 + angle_index : Optional for dynamic sel, 3 x n_angle n2a_index : n_angle Broadcast indices from extended node(j) to angle(ijk). eij2a_index : n_angle @@ -746,25 +748,24 @@ def forward( a_updated : nf x nloc x a_nnei x a_nnei x a_dim Updated angle embedding. """ - nb, nloc, nnei, _ = edge_ebd.shape + nb, nloc, nnei = nlist.shape nall = node_ebd_ext.shape[1] node_ebd = node_ebd_ext[:, :nloc, :] - n_edge = int(nlist_mask.sum().item()) if paddle.in_dynamic_mode(): assert [nb, nloc] == node_ebd.shape[:2] if not self.use_dynamic_sel: if paddle.in_dynamic_mode(): assert [nb, nloc, nnei, 3] == h2.shape + n_edge = None else: - if paddle.in_dynamic_mode(): - assert [n_edge, 3] == h2.shape + n_edge = h2.shape[0] del a_nlist # may be used in the future - n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1] + n2e_index, n_ext2e_index = edge_index[0], edge_index[1] n2a_index, eij2a_index, eik2a_index = ( - angle_index[:, 0], - angle_index[:, 1], - angle_index[:, 2], + angle_index[0], + angle_index[1], + angle_index[2], ) # nb x nloc x nnei x n_dim [OR] n_edge x n_dim @@ -896,7 +897,7 @@ def forward( n2e_index, average=False, num_owner=nb * nloc, - ).reshape(nb, nloc, node_edge_update.shape[-1]) + ).reshape([nb, nloc, node_edge_update.shape[-1]]) / self.dynamic_e_sel ) ) @@ -1042,7 +1043,9 @@ def forward( if not self.use_dynamic_sel: # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( - a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update + a_sw.unsqueeze(-1).unsqueeze(-1) + * a_sw.unsqueeze(-2).unsqueeze(-1) + * edge_angle_update ) # nb x nloc x a_nnei x e_dim reduced_edge_angle_update = paddle.sum( diff --git a/deepmd/pd/model/descriptor/repflows.py b/deepmd/pd/model/descriptor/repflows.py index 2214116bbc..ca55c82015 100644 --- a/deepmd/pd/model/descriptor/repflows.py +++ b/deepmd/pd/model/descriptor/repflows.py @@ -515,7 +515,8 @@ def forward( a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask] else: # avoid jit assertion - edge_index = angle_index = paddle.zeros([1, 3], dtype=nlist.dtype) + edge_index = paddle.zeros([2, 1], dtype=nlist.dtype) + angle_index = paddle.zeros([3, 1], dtype=nlist.dtype) # get edge and angle embedding # nb x nloc x nnei x e_dim [OR] n_edge x e_dim if not self.edge_init_use_dist: @@ -566,7 +567,7 @@ def forward( edge_ebd, h2, sw, - owner=edge_index[:, 0], + owner=edge_index[0], num_owner=nframes * nloc, nb=nframes, nloc=nloc, diff --git a/deepmd/pd/model/network/utils.py b/deepmd/pd/model/network/utils.py index b2ed3ac24e..ceeb9bbf96 100644 --- a/deepmd/pd/model/network/utils.py +++ b/deepmd/pd/model/network/utils.py @@ -29,19 +29,24 @@ def aggregate( ------- output: [num_owner, feature_dim] """ - bin_count = paddle.bincount(owners) - bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count)) - - if (num_owner is not None) and (bin_count.shape[0] != num_owner): - difference = num_owner - bin_count.shape[0] - bin_count = paddle.concat( - [bin_count, paddle.ones([difference], dtype=bin_count.dtype)] - ) + if num_owner is None or average: + # requires bincount + bin_count = paddle.bincount(owners) + bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count)) + + if (num_owner is not None) and (bin_count.shape[0] != num_owner): + difference = num_owner - bin_count.shape[0] + bin_count = paddle.concat( + [bin_count, paddle.ones([difference], dtype=bin_count.dtype)] + ) + else: + bin_count = None # make sure this operation is done on the same device of data and owners - output = paddle.zeros([bin_count.shape[0], data.shape[1]]) + output = paddle.zeros([num_owner, data.shape[1]]) output = output.index_add_(owners, 0, data) if average: + assert bin_count is not None output = (output.T / bin_count).T return output @@ -51,6 +56,7 @@ def get_graph_index( nlist_mask: paddle.Tensor, a_nlist_mask: paddle.Tensor, nall: int, + use_loc_mapping: bool = True, ): """ Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`. @@ -68,12 +74,12 @@ def get_graph_index( Returns ------- - edge_index : n_edge x 2 + edge_index : 2 x n_edge n2e_index : n_edge Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). n_ext2e_index : n_edge Broadcast indices from extended node(j) to edge(ij). - angle_index : n_angle x 3 + angle_index : 3 x n_angle n2a_index : n_angle Broadcast indices from extended node(j) to angle(ijk). eij2a_index : n_angle @@ -100,7 +106,9 @@ def get_graph_index( n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0] # node_ext(j) to edge(ij) index_select - frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall + frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * ( + nall if not use_loc_mapping else nloc + ) shifted_nlist = nlist + frame_shift[:, None, None] # n_edge n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1] @@ -129,9 +137,7 @@ def get_graph_index( # n_angle eik2a_index = edge_index_ik[a_nlist_mask_3d] - return paddle.concat( - [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1 - ), paddle.concat( - [n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)], - axis=-1, - ) + edge_index_result = paddle.stack([n2e_index, n_ext2e_index], axis=0) + angle_index_result = paddle.stack([n2a_index, eij2a_index, eik2a_index], axis=0) + + return edge_index_result, angle_index_result From 2e2efe7d150c416fd9598f16f57171a70f1467ad Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 9 Jul 2025 14:10:22 +0800 Subject: [PATCH 2/2] fix repflow --- deepmd/pd/model/descriptor/repflows.py | 6 +- deepmd/pd/model/network/utils.py | 2 +- source/tests/pd/model/test_dynamic_sel.py | 155 ++++++++++++++++++++++ 3 files changed, 160 insertions(+), 3 deletions(-) create mode 100644 source/tests/pd/model/test_dynamic_sel.py diff --git a/deepmd/pd/model/descriptor/repflows.py b/deepmd/pd/model/descriptor/repflows.py index ca55c82015..2fe9ff8470 100644 --- a/deepmd/pd/model/descriptor/repflows.py +++ b/deepmd/pd/model/descriptor/repflows.py @@ -539,8 +539,10 @@ def forward( # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode if not parallel_mode: assert mapping is not None - node_ebd_ext = paddle.take_along_axis( - node_ebd, mapping, 1, broadcast=False + node_ebd_ext = ( + paddle.take_along_axis(node_ebd, mapping, 1, broadcast=False) + if not self.use_loc_mapping + else node_ebd ) else: raise NotImplementedError("Not implemented") diff --git a/deepmd/pd/model/network/utils.py b/deepmd/pd/model/network/utils.py index ceeb9bbf96..9fae72c2cc 100644 --- a/deepmd/pd/model/network/utils.py +++ b/deepmd/pd/model/network/utils.py @@ -44,7 +44,7 @@ def aggregate( # make sure this operation is done on the same device of data and owners output = paddle.zeros([num_owner, data.shape[1]]) - output = output.index_add_(owners, 0, data) + output = output.index_add_(owners, 0, data.astype(output.dtype)) if average: assert bin_count is not None output = (output.T / bin_count).T diff --git a/source/tests/pd/model/test_dynamic_sel.py b/source/tests/pd/model/test_dynamic_sel.py new file mode 100644 index 0000000000..a605d97f85 --- /dev/null +++ b/source/tests/pd/model/test_dynamic_sel.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import paddle + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.pd.model.descriptor import ( + DescrptDPA3, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PD_FLOAT_PRECISION + + +class TestDescrptDPA3DynamicSel(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + acr, + nme, + prec, + ect, + optim, + ) in itertools.product( + [True, False], # update_angle + ["res_residual"], # update_style + ["norm", "const"], # update_residual_init + [0, 1], # a_compress_rate + [1, 2], # n_multi_edge_message + ["float64"], # precision + [False], # use_econf_tebd + [True, False], # optim_update + ): + dtype = PRECISION_DICT[prec] + # rtol, atol = get_tols(prec) + rtol, atol = 1e-5, 1e-7 + if prec == "float64": + atol = 1e-8 # marginal GPU test cases... + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=10, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei, + a_compress_rate=acr, + n_multi_edge_message=nme, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + optim_update=optim, + smooth_edge_update=True, + sel_reduce_factor=1.0, # test consistent when sel_reduce_factor == 1.0 + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + repflow.use_dynamic_sel = True + + # dpa3 new impl + dd1 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = paddle.to_tensor(davg, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + rd0, _, _, _, _ = dd0( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype=paddle.int64).to( + device=env.DEVICE + ), + paddle.to_tensor(self.nlist, dtype=paddle.int64).to(device=env.DEVICE), + paddle.to_tensor(self.mapping, dtype=paddle.int64).to( + device=env.DEVICE + ), + ) + # serialization + dd1.repflows.mean = paddle.to_tensor(davg, dtype=dtype).to( + device=env.DEVICE + ) + dd1.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + rd1, _, _, _, _ = dd1( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype=paddle.int64).to( + device=env.DEVICE + ), + paddle.to_tensor(self.nlist, dtype=paddle.int64).to(device=env.DEVICE), + paddle.to_tensor(self.mapping, dtype=paddle.int64).to( + device=env.DEVICE + ), + ) + np.testing.assert_allclose( + rd0.numpy(), + rd1.numpy(), + rtol=rtol, + atol=atol, + )