From 25bb8212610bf281bf74732824f382d74f694478 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 2 Nov 2024 03:27:48 -0400 Subject: [PATCH] feat(jax/array-api): DPA-2 (#4294) ## Summary by CodeRabbit - **New Features** - Introduced new classes for enhanced descriptor functionality, including `DescrptDPA2`, `DescrptBlockRepformers`, and `DescrptBlockSeTTebd`. - Added serialization and deserialization methods for better state management of descriptor objects. - **Improvements** - Enhanced compatibility with various array backends through the integration of `array_api_compat`. - Refactored existing methods to utilize new array API functions for improved performance. - Updated documentation to reflect JAX as a supported backend alongside PyTorch. - **Bug Fixes** - Updated handling of attributes in several classes to ensure correct deserialization and type safety. - **Tests** - Enhanced testing capabilities for JAX and Array API Strict backend integration, including conditional imports and new evaluation methods. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> --- deepmd/dpmodel/descriptor/dpa2.py | 37 ++- deepmd/dpmodel/descriptor/repformers.py | 311 ++++++++++++------ deepmd/dpmodel/utils/nlist.py | 23 +- deepmd/jax/descriptor/dpa2.py | 61 ++++ deepmd/jax/descriptor/repformers.py | 107 ++++++ doc/model/dpa2.md | 4 +- .../tests/array_api_strict/descriptor/dpa2.py | 57 ++++ .../array_api_strict/descriptor/repformers.py | 98 ++++++ .../tests/consistent/descriptor/test_dpa2.py | 36 ++ 9 files changed, 616 insertions(+), 118 deletions(-) create mode 100644 deepmd/jax/descriptor/dpa2.py create mode 100644 deepmd/jax/descriptor/repformers.py create mode 100644 source/tests/array_api_strict/descriptor/dpa2.py create mode 100644 source/tests/array_api_strict/descriptor/repformers.py diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 1dbb14961e..200747c0ef 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -4,11 +4,18 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( EnvMat, NetworkCollection, @@ -787,9 +794,10 @@ def call( The smooth switch function. shape: nf x nloc x nnei """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape - nall = coord_ext.reshape(nframes, -1).shape[1] // 3 + nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3 # nlists nlist_dict = build_multiple_neighbor_list( coord_ext, @@ -798,7 +806,10 @@ def call( self.nsel_list, ) # repinit - g1_ext = self.type_embedding.call()[atype_ext] + g1_ext = xp.reshape( + xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0), + (nframes, nall, self.tebd_dim), + ) g1_inp = g1_ext[:, :nloc, :] g1, _, _, _, _ = self.repinit( nlist_dict[ @@ -823,7 +834,7 @@ def call( g1_ext, mapping, ) - g1 = np.concatenate([g1, g1_three_body], axis=-1) + g1 = xp.concat([g1, g1_three_body], axis=-1) # linear to change shape g1 = self.g1_shape_tranform(g1) if self.add_tebd_to_repinit_out: @@ -831,8 +842,10 @@ def call( g1 = g1 + self.tebd_transform(g1_inp) # mapping g1 assert mapping is not None - mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1])) - g1_ext = np.take_along_axis(g1, mapping_ext, axis=1) + mapping_ext = xp.tile( + xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1]) + ) + g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1) # repformer g1, g2, h2, rot_mat, sw = self.repformers( nlist_dict[ @@ -846,7 +859,7 @@ def call( mapping, ) if self.concat_output_tebd: - g1 = np.concatenate([g1, g1_inp], axis=-1) + g1 = xp.concat([g1, g1_inp], axis=-1) return g1, rot_mat, g2, h2, sw def serialize(self) -> dict: @@ -883,8 +896,8 @@ def serialize(self) -> dict: "embeddings": repinit.embeddings.serialize(), "env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(), "@variables": { - "davg": repinit["davg"], - "dstd": repinit["dstd"], + "davg": to_numpy_array(repinit["davg"]), + "dstd": to_numpy_array(repinit["dstd"]), }, } if repinit.tebd_input_mode in ["strip"]: @@ -896,8 +909,8 @@ def serialize(self) -> dict: "repformer_layers": [layer.serialize() for layer in repformers.layers], "env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(), "@variables": { - "davg": repformers["davg"], - "dstd": repformers["dstd"], + "davg": to_numpy_array(repformers["davg"]), + "dstd": to_numpy_array(repformers["dstd"]), }, } data.update( @@ -913,8 +926,8 @@ def serialize(self) -> dict: repinit_three_body.rcut, repinit_three_body.rcut_smth ).serialize(), "@variables": { - "davg": repinit_three_body["davg"], - "dstd": repinit_three_body["dstd"], + "davg": to_numpy_array(repinit_three_body["davg"]), + "dstd": to_numpy_array(repinit_three_body["dstd"]), }, } if repinit_three_body.tebd_input_mode in ["strip"]: diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index ef79ecdd28..5422ff345e 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -5,12 +5,19 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( EnvMat, PairExcludeMask, @@ -38,6 +45,28 @@ ) +def xp_transpose_01423(x): + xp = array_api_compat.array_namespace(x) + x_shape2 = x.shape[2] + x_shape3 = x.shape[3] + x_shape4 = x.shape[4] + x = xp.reshape(x, (x.shape[0], x.shape[1], x_shape2 * x_shape3, x_shape4)) + x = xp.matrix_transpose(x) + x = xp.reshape(x, (x.shape[0], x.shape[1], x_shape4, x_shape2, x_shape3)) + return x + + +def xp_transpose_01342(x): + xp = array_api_compat.array_namespace(x) + x_shape2 = x.shape[2] + x_shape3 = x.shape[3] + x_shape4 = x.shape[4] + x = xp.reshape(x, (x.shape[0], x.shape[1], x_shape2, x_shape3 * x_shape4)) + x = xp.matrix_transpose(x) + x = xp.reshape(x, (x.shape[0], x.shape[1], x_shape3, x_shape4, x_shape2)) + return x + + @DescriptorBlock.register("se_repformer") @DescriptorBlock.register("se_uni") class DescrptBlockRepformers(NativeOP, DescriptorBlock): @@ -360,8 +389,9 @@ def call( atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) - nlist = np.where(exclude_mask, nlist, -1) + nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( coord_ext, atype_ext, nlist, self.mean, self.stddev @@ -370,8 +400,8 @@ def call( # nf x nloc x nnei nlist_mask = nlist != -1 # nf x nloc x nnei - sw = sw.reshape(nf, nloc, nnei) - sw = np.where(nlist_mask, sw, 0.0) + sw = xp.reshape(sw, (nf, nloc, nnei)) + sw = xp.where(nlist_mask, sw, xp.zeros_like(sw)) # nf x nloc x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :] assert list(atype_embd.shape) == [nf, nloc, self.g1_dim] @@ -379,22 +409,22 @@ def call( g1 = self.act(atype_embd) # nf x nloc x nnei x 1, nf x nloc x nnei x 3 if not self.direct_dist: - g2, h2 = np.split(dmatrix, [1], axis=-1) + g2, h2 = xp.split(dmatrix, [1], axis=-1) else: - g2, h2 = np.linalg.norm(diff, axis=-1, keepdims=True), diff + g2, h2 = xp.linalg.vector_norm(diff, axis=-1, keepdims=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut # nf x nloc x nnei x ng2 g2 = self.act(self.g2_embd(g2)) # set all padding positions to index of 0 # if a neighbor is real or not is indicated by nlist_mask - nlist[nlist == -1] = 0 + nlist = xp.where(nlist == -1, xp.zeros_like(nlist), nlist) # nf x nall x ng1 - mapping = np.tile(mapping.reshape(nf, -1, 1), (1, 1, self.g1_dim)) + mapping = xp.tile(xp.reshape(mapping, (nf, -1, 1)), (1, 1, self.g1_dim)) for idx, ll in enumerate(self.layers): # g1: nf x nloc x ng1 # g1_ext: nf x nall x ng1 - g1_ext = np.take_along_axis(g1, mapping, axis=1) + g1_ext = xp_take_along_axis(g1, mapping, axis=1) g1, g2, h2 = ll.call( g1_ext, g2, @@ -415,8 +445,9 @@ def call( use_sqrt_nnei=self.use_sqrt_nnei, ) # (nf x nloc) x ng2 x 3 - rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) - return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw + # rot_mat = xp.transpose(h2g2, (0, 1, 3, 2)) + rot_mat = xp.matrix_transpose(h2g2) + return g1, g2, h2, xp.reshape(rot_mat, (nf, nloc, self.dim_emb, 3)), sw def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" @@ -426,6 +457,72 @@ def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return False + @classmethod + def deserialize(cls, data): + """Deserialize the descriptor block.""" + data = data.copy() + g2_embd = NativeLayer.deserialize(data.pop("g2_embd")) + layers = [RepformerLayer.deserialize(dd) for dd in data.pop("repformer_layers")] + env_mat = EnvMat.deserialize(data.pop("env_mat")) + variables = data.pop("@variables") + davg = variables["davg"] + dstd = variables["dstd"] + obj = cls(**data) + obj.g2_embd = g2_embd + obj.layers = layers + obj.env_mat = env_mat + obj.mean = davg + obj.stddev = dstd + return obj + + def serialize(self): + """Serialize the descriptor block.""" + return { + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "nlayers": self.nlayers, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "direct_dist": self.direct_dist, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + "set_davg_zero": self.set_davg_zero, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, + "ln_eps": self.ln_eps, + # variables + "g2_embd": self.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in self.layers], + "env_mat": self.env_mat.serialize(), + "@variables": { + "davg": to_numpy_array(self["davg"]), + "dstd": to_numpy_array(self["dstd"]), + }, + } + # translated by GPT and modified def get_residual( @@ -487,16 +584,17 @@ def _make_nei_g1( gg1: np.ndarray Neighbor-wise atomic invariant rep, with shape [nf, nloc, nnei, ng1]. """ + xp = array_api_compat.array_namespace(g1_ext, nlist) # nlist: nf x nloc x nnei nf, nloc, nnei = nlist.shape # g1_ext: nf x nall x ng1 ng1 = g1_ext.shape[-1] # index: nf x (nloc x nnei) x ng1 - index = np.tile(nlist.reshape(nf, nloc * nnei, 1), (1, 1, ng1)) + index = xp.tile(xp.reshape(nlist, (nf, nloc * nnei, 1)), (1, 1, ng1)) # gg1 : nf x (nloc x nnei) x ng1 - gg1 = np.take_along_axis(g1_ext, index, axis=1) + gg1 = xp_take_along_axis(g1_ext, index, axis=1) # gg1 : nf x nloc x nnei x ng1 - gg1 = gg1.reshape(nf, nloc, nnei, ng1) + gg1 = xp.reshape(gg1, (nf, nloc, nnei, ng1)) return gg1 @@ -514,7 +612,8 @@ def _apply_nlist_mask( nlist_mask Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. """ - masked_gg = np.where(nlist_mask[:, :, :, None], gg, 0.0) + xp = array_api_compat.array_namespace(gg, nlist_mask) + masked_gg = xp.where(nlist_mask[:, :, :, None], gg, xp.zeros_like(gg)) return masked_gg @@ -570,6 +669,7 @@ def _cal_hg( hg The transposed rotation matrix, with shape [nf, nloc, 3, ng]. """ + xp = array_api_compat.array_namespace(g, h, nlist_mask, sw) # g: nf x nloc x nnei x ng # h: nf x nloc x nnei x 3 # msk: nf x nloc x nnei @@ -580,21 +680,23 @@ def _cal_hg( if not smooth: # nf x nloc if not use_sqrt_nnei: - invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + invnnei = 1.0 / (epsilon + xp.sum(xp.astype(nlist_mask, g.dtype), axis=-1)) else: - invnnei = 1.0 / (epsilon + np.sqrt(np.sum(nlist_mask, axis=-1))) + invnnei = 1.0 / ( + epsilon + xp.sqrt(xp.sum(xp.astype(nlist_mask, g.dtype), axis=-1)) + ) # nf x nloc x 1 x 1 - invnnei = invnnei[:, :, np.newaxis, np.newaxis] + invnnei = invnnei[:, :, xp.newaxis, xp.newaxis] else: g = _apply_switch(g, sw) if not use_sqrt_nnei: - invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1, 1), dtype=g.dtype) else: - invnnei = (1.0 / (float(nnei) ** 0.5)) * np.ones( + invnnei = (1.0 / (float(nnei) ** 0.5)) * xp.ones( (nf, nloc, 1, 1), dtype=g.dtype ) # nf x nloc x 3 x ng - hg = np.matmul(np.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei + hg = xp.matmul(xp.matrix_transpose(h), g) * invnnei return hg @@ -614,14 +716,15 @@ def _cal_grrg(hg: np.ndarray, axis_neuron: int) -> np.ndarray: grrg Atomic invariant rep, with shape [nf, nloc, (axis_neuron * ng)]. """ + xp = array_api_compat.array_namespace(hg) # nf x nloc x 3 x ng nf, nloc, _, ng = hg.shape # nf x nloc x 3 x axis - hgm = np.split(hg, [axis_neuron], axis=-1)[0] + hgm = hg[..., :axis_neuron] # nf x nloc x axis_neuron x ng - grrg = np.matmul(np.transpose(hgm, axes=(0, 1, 3, 2)), hg) / (3.0**1) + grrg = xp.matmul(xp.matrix_transpose(hgm), hg) / (3.0**1) # nf x nloc x (axis_neuron * ng) - grrg = grrg.reshape(nf, nloc, axis_neuron * ng) + grrg = xp.reshape(grrg, (nf, nloc, axis_neuron * ng)) return grrg @@ -718,6 +821,7 @@ def call( nlist_mask: np.ndarray, # nf x nloc x nnei sw: np.ndarray, # nf x nloc x nnei ) -> np.ndarray: + xp = array_api_compat.array_namespace(g2, h2, nlist_mask, sw) ( nf, nloc, @@ -726,43 +830,47 @@ def call( ) = g2.shape nd, nh = self.hidden_dim, self.head_num # nf x nloc x nnei x nd x (nh x 2) - g2qk = self.mapqk(g2).reshape(nf, nloc, nnei, nd, nh * 2) + g2qk = self.mapqk(g2) + g2qk = xp.reshape(g2qk, (nf, nloc, nnei, nd, nh * 2)) # nf x nloc x (nh x 2) x nnei x nd - g2qk = np.transpose(g2qk, (0, 1, 4, 2, 3)) + # g2qk = xp.transpose(g2qk, (0, 1, 4, 2, 3)) + g2qk = xp_transpose_01423(g2qk) # nf x nloc x nh x nnei x nd - g2q, g2k = np.split(g2qk, [nh], axis=2) + # g2q, g2k = xp.split(g2qk, [nh], axis=2) + g2q = g2qk[:, :, :nh, :, :] + g2k = g2qk[:, :, nh:, :, :] # g2q = np.linalg.norm(g2q, axis=-1) # g2k = np.linalg.norm(g2k, axis=-1) # nf x nloc x nh x nnei x nnei - attnw = np.matmul(g2q, np.transpose(g2k, axes=(0, 1, 2, 4, 3))) / nd**0.5 + attnw = xp.matmul(g2q, xp.matrix_transpose(g2k)) / nd**0.5 if self.has_gate: - gate = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))).reshape( - nf, nloc, 1, nnei, nnei - ) + gate = xp.matmul(h2, xp.matrix_transpose(h2)) + gate = xp.reshape(gate, (nf, nloc, 1, nnei, nnei)) attnw = attnw * gate # mask the attenmap, nf x nloc x 1 x 1 x nnei - attnw_mask = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=2) + attnw_mask = ~xp.expand_dims(xp.expand_dims(nlist_mask, axis=2), axis=2) # mask the attenmap, nf x nloc x 1 x nnei x 1 - attnw_mask_c = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=-1) + attnw_mask_c = ~xp.expand_dims(xp.expand_dims(nlist_mask, axis=2), axis=-1) if self.smooth: attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ :, :, None, None, : ] - self.attnw_shift else: - attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = xp.where(attnw_mask, xp.full_like(attnw, -xp.inf), attnw) attnw = np_softmax(attnw, axis=-1) - attnw = np.where(attnw_mask, 0.0, attnw) + attnw = xp.where(attnw_mask, xp.zeros_like(attnw), attnw) # nf x nloc x nh x nnei x nnei - attnw = np.where(attnw_mask_c, 0.0, attnw) + attnw = xp.where(attnw_mask_c, xp.zeros_like(attnw), attnw) if self.smooth: attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] # nf x nloc x nnei x nnei - h2h2t = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))) / 3.0**0.5 + h2h2t = xp.matmul(h2, xp.matrix_transpose(h2)) / 3.0**0.5 # nf x nloc x nh x nnei x nnei ret = attnw * h2h2t[:, :, None, :, :] # ret = np.exp(g2qk - np.max(g2qk, axis=-1, keepdims=True)) # nf x nloc x nnei x nnei x nh - ret = np.transpose(ret, (0, 1, 3, 4, 2)) + # ret = xp.transpose(ret, (0, 1, 3, 4, 2)) + ret = xp_transpose_01342(ret) return ret def serialize(self) -> dict: @@ -835,19 +943,22 @@ def call( AA: np.ndarray, # nf x nloc x nnei x nnei x nh g2: np.ndarray, # nf x nloc x nnei x ng2 ) -> np.ndarray: + xp = array_api_compat.array_namespace(AA, g2) nf, nloc, nnei, ng2 = g2.shape nh = self.head_num # nf x nloc x nnei x ng2 x nh - g2v = self.mapv(g2).reshape(nf, nloc, nnei, ng2, nh) + g2v = self.mapv(g2) + g2v = xp.reshape(g2v, (nf, nloc, nnei, ng2, nh)) # nf x nloc x nh x nnei x ng2 - g2v = np.transpose(g2v, (0, 1, 4, 2, 3)) + g2v = xp_transpose_01423(g2v) # g2v = np.linalg.norm(g2v, axis=-1) # nf x nloc x nh x nnei x nnei - AA = np.transpose(AA, (0, 1, 4, 2, 3)) + AA = xp_transpose_01423(AA) # nf x nloc x nh x nnei x ng2 - ret = np.matmul(AA, g2v) + ret = xp.matmul(AA, g2v) # nf x nloc x nnei x ng2 x nh - ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + ret = xp_transpose_01342(ret) + ret = xp.reshape(ret, (nf, nloc, nnei, (ng2 * nh))) # nf x nloc x nnei x ng2 return self.head_map(ret) @@ -910,19 +1021,21 @@ def call( AA: np.ndarray, # nf x nloc x nnei x nnei x nh h2: np.ndarray, # nf x nloc x nnei x 3 ) -> np.ndarray: + xp = array_api_compat.array_namespace(AA, h2) nf, nloc, nnei, _ = h2.shape nh = self.head_num # nf x nloc x nh x nnei x nnei - AA = np.transpose(AA, (0, 1, 4, 2, 3)) - h2m = np.expand_dims(h2, axis=2) + AA = xp_transpose_01423(AA) + h2m = xp.expand_dims(h2, axis=2) # nf x nloc x nh x nnei x 3 - h2m = np.tile(h2m, (1, 1, nh, 1, 1)) + h2m = xp.tile(h2m, (1, 1, nh, 1, 1)) # nf x nloc x nh x nnei x 3 - ret = np.matmul(AA, h2m) + ret = xp.matmul(AA, h2m) # nf x nloc x nnei x 3 x nh - ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, 3, nh) + ret = xp_transpose_01342(ret) + ret = xp.reshape(ret, (nf, nloc, nnei, 3, nh)) # nf x nloc x nnei x 3 - return np.squeeze(self.head_map(ret), axis=-1) + return xp.squeeze(self.head_map(ret), axis=-1) def serialize(self) -> dict: """Serialize the networks to a dict. @@ -1005,49 +1118,49 @@ def call( nlist_mask: np.ndarray, # nf x nloc x nnei sw: np.ndarray, # nf x nloc x nnei ) -> np.ndarray: + xp = array_api_compat.array_namespace(g1, gg1, nlist_mask, sw) nf, nloc, nnei = nlist_mask.shape ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num assert ni == g1.shape[-1] assert ni == gg1.shape[-1] # nf x nloc x nd x nh - g1q = self.mapq(g1).reshape(nf, nloc, nd, nh) + g1q = self.mapq(g1) + g1q = xp.reshape(g1q, (nf, nloc, nd, nh)) # nf x nloc x nh x nd - g1q = np.transpose(g1q, (0, 1, 3, 2)) + g1q = xp.matrix_transpose(g1q) # nf x nloc x nnei x (nd+ni) x nh - gg1kv = self.mapkv(gg1).reshape(nf, nloc, nnei, nd + ni, nh) - gg1kv = np.transpose(gg1kv, (0, 1, 4, 2, 3)) + gg1kv = self.mapkv(gg1) + gg1kv = xp.reshape(gg1kv, (nf, nloc, nnei, nd + ni, nh)) + gg1kv = xp_transpose_01423(gg1kv) # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1 - gg1k, gg1v = np.split(gg1kv, [nd], axis=-1) + # gg1k, gg1v = xp.split(gg1kv, [nd], axis=-1) + gg1k = gg1kv[:, :, :, :, :nd] + gg1v = gg1kv[:, :, :, :, nd:] # nf x nloc x nh x 1 x nnei attnw = ( - np.matmul( - np.expand_dims(g1q, axis=-2), np.transpose(gg1k, axes=(0, 1, 2, 4, 3)) - ) - / nd**0.5 + xp.matmul(xp.expand_dims(g1q, axis=-2), xp.matrix_transpose(gg1k)) / nd**0.5 ) # nf x nloc x nh x nnei - attnw = np.squeeze(attnw, axis=-2) + attnw = xp.squeeze(attnw, axis=-2) # mask the attenmap, nf x nloc x 1 x nnei - attnw_mask = ~np.expand_dims(nlist_mask, axis=-2) + attnw_mask = ~xp.expand_dims(nlist_mask, axis=-2) # nf x nloc x nh x nnei if self.smooth: - attnw = (attnw + self.attnw_shift) * np.expand_dims( + attnw = (attnw + self.attnw_shift) * xp.expand_dims( sw, axis=-2 ) - self.attnw_shift else: - attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = xp.where(attnw_mask, xp.full_like(attnw, -xp.inf), attnw) attnw = np_softmax(attnw, axis=-1) - attnw = np.where(attnw_mask, 0.0, attnw) + attnw = xp.where(attnw_mask, xp.zeros_like(attnw), attnw) if self.smooth: - attnw = attnw * np.expand_dims(sw, axis=-2) + attnw = attnw * xp.expand_dims(sw, axis=-2) # nf x nloc x nh x ng1 - ret = ( - np.matmul(np.expand_dims(attnw, axis=-2), gg1v) - .squeeze(-2) - .reshape(nf, nloc, nh * ni) - ) + ret = xp.matmul(xp.expand_dims(attnw, axis=-2), gg1v) + ret = xp.squeeze(ret, axis=-2) + ret = xp.reshape(ret, (nf, nloc, nh * ni)) # nf x nloc x ng1 ret = self.head_map(ret) return ret @@ -1178,12 +1291,12 @@ def __init__( ], "'update_residual_init' only support 'norm' or 'const'!" self.update_residual = update_residual self.update_residual_init = update_residual_init - self.g1_residual = [] - self.g2_residual = [] - self.h2_residual = [] + g1_residual = [] + g2_residual = [] + h2_residual = [] if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1217,7 +1330,7 @@ def __init__( seed=child_seed(seed, 2), ) if self.update_style == "res_residual": - self.g2_residual.append( + g2_residual.append( get_residual( g2_dim, self.update_residual, @@ -1234,7 +1347,7 @@ def __init__( seed=child_seed(seed, 15), ) if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1263,7 +1376,7 @@ def __init__( seed=child_seed(seed, 4), ) if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1281,7 +1394,7 @@ def __init__( seed=child_seed(seed, 5), ) if self.update_style == "res_residual": - self.g2_residual.append( + g2_residual.append( get_residual( g2_dim, self.update_residual, @@ -1312,7 +1425,7 @@ def __init__( seed=child_seed(seed, 9), ) if self.update_style == "res_residual": - self.g2_residual.append( + g2_residual.append( get_residual( g2_dim, self.update_residual, @@ -1327,7 +1440,7 @@ def __init__( g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 11) ) if self.update_style == "res_residual": - self.h2_residual.append( + h2_residual.append( get_residual( 1, self.update_residual, @@ -1346,7 +1459,7 @@ def __init__( seed=child_seed(seed, 13), ) if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1356,6 +1469,10 @@ def __init__( ) ) + self.g1_residual = g1_residual + self.g2_residual = g2_residual + self.h2_residual = h2_residual + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret = g1d if not self.g1_out_mlp else 0 if self.update_g1_has_grrg: @@ -1408,35 +1525,40 @@ def _update_g1_conv( The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nf x nloc x nnei. """ + xp = array_api_compat.array_namespace(gg1, g2, nlist_mask, sw) assert self.proj_g1g2 is not None nf, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] if not self.g1_out_conv: # gg1 : nf x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) + gg1 = self.proj_g1g2(gg1) + gg1 = xp.reshape(gg1, (nf, nloc, nnei, ng2)) else: # gg1 : nf x nloc x nnei x ng1 - gg1 = gg1.reshape(nf, nloc, nnei, ng1) + gg1 = xp.reshape(gg1, (nf, nloc, nnei, ng1)) # nf x nloc x nnei x ng2/ng1 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth # nf x nloc - invnnei = 1.0 / (self.epsilon + np.sum(nlist_mask, axis=-1)) + invnnei = 1.0 / ( + self.epsilon + xp.sum(xp.astype(nlist_mask, gg1.dtype), axis=-1) + ) # nf x nloc x 1 - invnnei = invnnei[:, :, np.newaxis] + invnnei = invnnei[:, :, xp.newaxis] else: gg1 = _apply_switch(gg1, sw) - invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1), dtype=gg1.dtype) + invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1), dtype=gg1.dtype) if not self.g1_out_conv: # nf x nloc x ng2 - g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei else: # nf x nloc x ng1 - g2 = self.proj_g1g2(g2).reshape(nf, nloc, nnei, ng1) + g2 = self.proj_g1g2(g2) + g2 = xp.reshape(g2, (nf, nloc, nnei, ng1)) # nb x nloc x ng1 - g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei return g1_11 def _update_g2_g1g1( @@ -1461,7 +1583,8 @@ def _update_g2_g1g1( The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nf x nloc x nnei. """ - ret = np.expand_dims(g1, axis=-2) * gg1 + xp = array_api_compat.array_namespace(g1, gg1, nlist_mask, sw) + ret = xp.expand_dims(g1, axis=-2) * gg1 # nf x nloc x nnei x ng1 ret = _apply_nlist_mask(ret, nlist_mask) if self.smooth: @@ -1493,6 +1616,7 @@ def call( g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant """ + xp = array_api_compat.array_namespace(g1_ext, g2, h2, nlist, nlist_mask, sw) cal_gg1 = ( self.update_g1_has_drrd or self.update_g1_has_conv @@ -1502,7 +1626,8 @@ def call( nf, nloc, nnei, _ = g2.shape nall = g1_ext.shape[1] - g1, _ = np.split(g1_ext, [nloc], axis=1) + # g1, _ = xp.split(g1_ext, [nloc], axis=1) + g1 = g1_ext[:, :nloc, :] assert (nf, nloc) == g1.shape[:2] assert (nf, nloc, nnei) == h2.shape[:3] @@ -1592,7 +1717,7 @@ def call( # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] # conv grrg drrd - g1_1 = self.act(self.linear1(np.concatenate(g1_mlp, axis=-1))) + g1_1 = self.act(self.linear1(xp.concat(g1_mlp, axis=-1))) g1_update.append(g1_1) if self.update_g1_has_attn: @@ -1752,9 +1877,9 @@ def serialize(self) -> dict: if self.update_style == "res_residual": data.update( { - "g1_residual": self.g1_residual, - "g2_residual": self.g2_residual, - "h2_residual": self.h2_residual, + "g1_residual": [to_numpy_array(aa) for aa in self.g1_residual], + "g2_residual": [to_numpy_array(aa) for aa in self.g2_residual], + "h2_residual": [to_numpy_array(aa) for aa in self.h2_residual], } ) return data diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index b827032588..7b3b25df36 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -215,30 +215,31 @@ def build_multiple_neighbor_list( value being the corresponding nlist. """ + xp = array_api_compat.array_namespace(coord, nlist) assert len(rcuts) == len(nsels) if len(rcuts) == 0: return {} nb, nloc, nsel = nlist.shape if nsel < nsels[-1]: - pad = -1 * np.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) - nlist = np.concatenate([nlist, pad], axis=-1) + pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) + nlist = xp.concat([nlist, pad], axis=-1) nsel = nsels[-1] - coord1 = coord.reshape(nb, -1, 3) + coord1 = xp.reshape(coord, (nb, -1, 3)) nall = coord1.shape[1] coord0 = coord1[:, :nloc, :] nlist_mask = nlist == -1 - tnlist_0 = nlist.copy() - tnlist_0[nlist_mask] = 0 - index = np.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3]) - coord2 = np.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3) + tnlist_0 = xp.where(nlist_mask, xp.zeros_like(nlist), nlist) + index = xp.tile(xp.reshape(tnlist_0, (nb, nloc * nsel, 1)), (1, 1, 3)) + coord2 = xp_take_along_axis(coord1, index, axis=1) + coord2 = xp.reshape(coord2, (nb, nloc, nsel, 3)) diff = coord2 - coord0[:, :, None, :] - rr = np.linalg.norm(diff, axis=-1) - rr = np.where(nlist_mask, float("inf"), rr) + rr = xp.linalg.vector_norm(diff, axis=-1) + rr = xp.where(nlist_mask, xp.full_like(rr, float("inf")), rr) nlist0 = nlist ret = {} for rc, ns in zip(rcuts[::-1], nsels[::-1]): - tnlist_1 = np.copy(nlist0[:, :, :ns]) - tnlist_1[rr[:, :, :ns] > rc] = -1 + tnlist_1 = nlist0[:, :, :ns] + tnlist_1 = xp.where(rr[:, :, :ns] > rc, xp.full_like(tnlist_1, -1), tnlist_1) ret[get_multiple_nlist_key(rc, ns)] = tnlist_1 return ret diff --git a/deepmd/jax/descriptor/dpa2.py b/deepmd/jax/descriptor/dpa2.py new file mode 100644 index 0000000000..8eea324b41 --- /dev/null +++ b/deepmd/jax/descriptor/dpa2.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP +from deepmd.dpmodel.utils.network import Identity as IdentityDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.descriptor.dpa1 import ( + DescrptBlockSeAtten, +) +from deepmd.jax.descriptor.repformers import ( + DescrptBlockRepformers, +) +from deepmd.jax.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd, +) +from deepmd.jax.utils.network import ( + NativeLayer, +) +from deepmd.jax.utils.type_embed import ( + TypeEmbedNet, +) + + +@BaseDescriptor.register("dpa2") +@flax_module +class DescrptDPA2(DescrptDPA2DP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"repinit"}: + value = DescrptBlockSeAtten.deserialize(value.serialize()) + elif name in {"repinit_three_body"}: + if value is not None: + value = DescrptBlockSeTTebd.deserialize(value.serialize()) + elif name in {"repformers"}: + value = DescrptBlockRepformers.deserialize(value.serialize()) + elif name in {"type_embedding"}: + value = TypeEmbedNet.deserialize(value.serialize()) + elif name in {"g1_shape_tranform", "tebd_transform"}: + if value is None: + pass + elif isinstance(value, NativeLayerDP): + value = NativeLayer.deserialize(value.serialize()) + elif isinstance(value, IdentityDP): + # IdentityDP doesn't contain any value - it's good to go + pass + else: + raise ValueError(f"Unknown layer type: {type(value)}") + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/repformers.py b/deepmd/jax/descriptor/repformers.py new file mode 100644 index 0000000000..77ca4a9a6b --- /dev/null +++ b/deepmd/jax/descriptor/repformers.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.repformers import ( + Atten2EquiVarApply as Atten2EquiVarApplyDP, +) +from deepmd.dpmodel.descriptor.repformers import Atten2Map as Atten2MapDP +from deepmd.dpmodel.descriptor.repformers import ( + Atten2MultiHeadApply as Atten2MultiHeadApplyDP, +) +from deepmd.dpmodel.descriptor.repformers import ( + DescrptBlockRepformers as DescrptBlockRepformersDP, +) +from deepmd.dpmodel.descriptor.repformers import LocalAtten as LocalAttenDP +from deepmd.dpmodel.descriptor.repformers import RepformerLayer as RepformerLayerDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + LayerNorm, + NativeLayer, +) + + +@flax_module +class DescrptBlockRepformers(DescrptBlockRepformersDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"layers"}: + value = [RepformerLayer.deserialize(layer.serialize()) for layer in value] + elif name == "g2_embd": + value = NativeLayer.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +@flax_module +class Atten2Map(Atten2MapDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapqk"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +@flax_module +class Atten2MultiHeadApply(Atten2MultiHeadApplyDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapv", "head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +@flax_module +class Atten2EquiVarApply(Atten2EquiVarApplyDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +@flax_module +class LocalAtten(LocalAttenDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapq", "mapkv", "head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +@flax_module +class RepformerLayer(RepformerLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"linear1", "linear2", "g1_self_mlp", "proj_g1g2", "proj_g1g1g2"}: + if value is not None: + value = NativeLayer.deserialize(value.serialize()) + elif name in {"g1_residual", "g2_residual", "h2_residual"}: + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + elif name in {"attn2g_map"}: + if value is not None: + value = Atten2Map.deserialize(value.serialize()) + elif name in {"attn2_mh_apply"}: + if value is not None: + value = Atten2MultiHeadApply.deserialize(value.serialize()) + elif name in {"attn2_lm"}: + if value is not None: + value = LayerNorm.deserialize(value.serialize()) + elif name in {"attn2_ev_apply"}: + if value is not None: + value = Atten2EquiVarApply.deserialize(value.serialize()) + elif name in {"loc_attn"}: + if value is not None: + value = LocalAtten.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/doc/model/dpa2.md b/doc/model/dpa2.md index 24ce5222e9..27ffc1b14d 100644 --- a/doc/model/dpa2.md +++ b/doc/model/dpa2.md @@ -1,7 +1,7 @@ -# Descriptor DPA-2 {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor DPA-2 {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: The DPA-2 model implementation. See https://arxiv.org/abs/2312.15492 for more details. diff --git a/source/tests/array_api_strict/descriptor/dpa2.py b/source/tests/array_api_strict/descriptor/dpa2.py new file mode 100644 index 0000000000..a510c6b461 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/dpa2.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP +from deepmd.dpmodel.utils.network import Identity as IdentityDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.network import ( + NativeLayer, +) +from ..utils.type_embed import ( + TypeEmbedNet, +) +from .base_descriptor import ( + BaseDescriptor, +) +from .dpa1 import ( + DescrptBlockSeAtten, +) +from .repformers import ( + DescrptBlockRepformers, +) +from .se_t_tebd import ( + DescrptBlockSeTTebd, +) + + +@BaseDescriptor.register("dpa2") +class DescrptDPA2(DescrptDPA2DP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_array_api_strict_array(value) + elif name in {"repinit"}: + value = DescrptBlockSeAtten.deserialize(value.serialize()) + elif name in {"repinit_three_body"}: + if value is not None: + value = DescrptBlockSeTTebd.deserialize(value.serialize()) + elif name in {"repformers"}: + value = DescrptBlockRepformers.deserialize(value.serialize()) + elif name in {"type_embedding"}: + value = TypeEmbedNet.deserialize(value.serialize()) + elif name in {"g1_shape_tranform", "tebd_transform"}: + if value is None: + pass + elif isinstance(value, NativeLayerDP): + value = NativeLayer.deserialize(value.serialize()) + elif isinstance(value, IdentityDP): + # IdentityDP doesn't contain any value - it's good to go + pass + else: + raise ValueError(f"Unknown layer type: {type(value)}") + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/repformers.py b/source/tests/array_api_strict/descriptor/repformers.py new file mode 100644 index 0000000000..ff65ff849f --- /dev/null +++ b/source/tests/array_api_strict/descriptor/repformers.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.repformers import ( + Atten2EquiVarApply as Atten2EquiVarApplyDP, +) +from deepmd.dpmodel.descriptor.repformers import Atten2Map as Atten2MapDP +from deepmd.dpmodel.descriptor.repformers import ( + Atten2MultiHeadApply as Atten2MultiHeadApplyDP, +) +from deepmd.dpmodel.descriptor.repformers import ( + DescrptBlockRepformers as DescrptBlockRepformersDP, +) +from deepmd.dpmodel.descriptor.repformers import LocalAtten as LocalAttenDP +from deepmd.dpmodel.descriptor.repformers import RepformerLayer as RepformerLayerDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + LayerNorm, + NativeLayer, +) + + +class DescrptBlockRepformers(DescrptBlockRepformersDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_array_api_strict_array(value) + elif name in {"layers"}: + value = [RepformerLayer.deserialize(layer.serialize()) for layer in value] + elif name == "g2_embd": + value = NativeLayer.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +class Atten2Map(Atten2MapDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapqk"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class Atten2MultiHeadApply(Atten2MultiHeadApplyDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapv", "head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class Atten2EquiVarApply(Atten2EquiVarApplyDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class LocalAtten(LocalAttenDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapq", "mapkv", "head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class RepformerLayer(RepformerLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"linear1", "linear2", "g1_self_mlp", "proj_g1g2", "proj_g1g1g2"}: + if value is not None: + value = NativeLayer.deserialize(value.serialize()) + elif name in {"g1_residual", "g2_residual", "h2_residual"}: + value = [to_array_api_strict_array(vv) for vv in value] + elif name in {"attn2g_map"}: + if value is not None: + value = Atten2Map.deserialize(value.serialize()) + elif name in {"attn2_mh_apply"}: + if value is not None: + value = Atten2MultiHeadApply.deserialize(value.serialize()) + elif name in {"attn2_lm"}: + if value is not None: + value = LayerNorm.deserialize(value.serialize()) + elif name in {"attn2_ev_apply"}: + if value is not None: + value = Atten2EquiVarApply.deserialize(value.serialize()) + elif name in {"loc_attn"}: + if value is not None: + value = LocalAtten.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 53f9ce4200..17c55db368 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -15,6 +15,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -28,6 +30,15 @@ else: DescrptDPA2PT = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2JAX +else: + DescrptDPA2JAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2Strict +else: + DescrptDPA2Strict = None + # not implemented DescrptDPA2TF = None @@ -269,9 +280,14 @@ def skip_tf(self) -> bool: ) = self.param return True + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP pt_class = DescrptDPA2PT + jax_class = DescrptDPA2JAX + array_api_strict_class = DescrptDPA2Strict args = descrpt_dpa2_args().append(Argument("ntypes", int, optional=False)) def setUp(self): @@ -367,6 +383,26 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)