From 737f7c8bb77a1a32f76f1f2d72d7099a95db2fc7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 09:11:43 -0400 Subject: [PATCH] feat(jax/array-api): hybrid descriptor (#4275) ## Summary by CodeRabbit - **New Features** - Introduced support for the JAX backend in the hybrid descriptor framework. - Added a new `DescrptHybrid` class with specialized attribute handling. - Enhanced testing framework to support additional backends, including JAX and strict array API. - **Bug Fixes** - Improved attribute handling in multiple descriptor classes to ensure proper deserialization and registration. - **Documentation** - Updated documentation to reflect the addition of JAX as a supported backend for hybrid descriptors. --------- Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/hybrid.py | 13 ++++--- deepmd/jax/descriptor/__init__.py | 4 +++ deepmd/jax/descriptor/hybrid.py | 26 ++++++++++++++ doc/model/train-hybrid.md | 4 +-- .../array_api_strict/descriptor/__init__.py | 19 ++++++++++ .../descriptor/base_descriptor.py | 11 ++++++ .../tests/array_api_strict/descriptor/dpa1.py | 5 +++ .../array_api_strict/descriptor/hybrid.py | 24 +++++++++++++ .../array_api_strict/descriptor/se_e2_a.py | 5 +++ .../array_api_strict/descriptor/se_e2_r.py | 5 +++ .../consistent/descriptor/test_hybrid.py | 35 +++++++++++++++++++ 11 files changed, 144 insertions(+), 7 deletions(-) create mode 100644 deepmd/jax/descriptor/hybrid.py create mode 100644 source/tests/array_api_strict/descriptor/base_descriptor.py create mode 100644 source/tests/array_api_strict/descriptor/hybrid.py diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 4eb14f29cf..0d89902e4a 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -66,7 +67,7 @@ def __init__( ), f"number of atom types in {ii}th descriptor {self.descrpt_list[0].__class__.__name__} does not match others" # if hybrid sel is larger than sub sel, the nlist needs to be cut for each type hybrid_sel = self.get_sel() - self.nlist_cut_idx: list[np.ndarray] = [] + nlist_cut_idx: list[np.ndarray] = [] if self.mixed_types() and not all( descrpt.mixed_types() for descrpt in self.descrpt_list ): @@ -92,7 +93,8 @@ def __init__( cut_idx = np.concatenate( [range(ss, ee) for ss, ee in zip(start_idx, end_idx)] ) - self.nlist_cut_idx.append(cut_idx) + nlist_cut_idx.append(cut_idx) + self.nlist_cut_idx = nlist_cut_idx def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -242,6 +244,7 @@ def call( sw The smooth switch function. """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) out_descriptor = [] out_gr = [] out_g2 = None @@ -258,7 +261,7 @@ def call( for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx): # cut the nlist to the correct length if self.mixed_types() == descrpt.mixed_types(): - nl = nlist[:, :, nci] + nl = xp.take(nlist, nci, axis=2) else: # mixed_types is True, but descrpt.mixed_types is False assert nl_distinguish_types is not None @@ -268,8 +271,8 @@ def call( if gr is not None: out_gr.append(gr) - out_descriptor = np.concatenate(out_descriptor, axis=-1) - out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None + out_descriptor = xp.concat(out_descriptor, axis=-1) + out_gr = xp.concat(out_gr, axis=-2) if out_gr else None return out_descriptor, out_gr, out_g2, out_h2, out_sw @classmethod diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 3ed096f9c1..cabee5a189 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -2,6 +2,9 @@ from deepmd.jax.descriptor.dpa1 import ( DescrptDPA1, ) +from deepmd.jax.descriptor.hybrid import ( + DescrptHybrid, +) from deepmd.jax.descriptor.se_e2_a import ( DescrptSeA, ) @@ -13,4 +16,5 @@ "DescrptSeA", "DescrptSeR", "DescrptDPA1", + "DescrptHybrid", ] diff --git a/deepmd/jax/descriptor/hybrid.py b/deepmd/jax/descriptor/hybrid.py new file mode 100644 index 0000000000..20fc5f838b --- /dev/null +++ b/deepmd/jax/descriptor/hybrid.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("hybrid") +@flax_module +class DescrptHybrid(DescrptHybridDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"nlist_cut_idx"}: + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + elif name in {"descrpt_list"}: + value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value] + + return super().__setattr__(name, value) diff --git a/doc/model/train-hybrid.md b/doc/model/train-hybrid.md index 1219d208a7..da3b40487b 100644 --- a/doc/model/train-hybrid.md +++ b/doc/model/train-hybrid.md @@ -1,7 +1,7 @@ -# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: This descriptor hybridizes multiple descriptors to form a new descriptor. For example, we have a list of descriptors denoted by $\mathcal D_1$, $\mathcal D_2$, ..., $\mathcal D_N$, the hybrid descriptor this the concatenation of the list, i.e. $\mathcal D = (\mathcal D_1, \mathcal D_2, \cdots, \mathcal D_N)$. diff --git a/source/tests/array_api_strict/descriptor/__init__.py b/source/tests/array_api_strict/descriptor/__init__.py index 6ceb116d85..5667fed858 100644 --- a/source/tests/array_api_strict/descriptor/__init__.py +++ b/source/tests/array_api_strict/descriptor/__init__.py @@ -1 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dpa1 import ( + DescrptDPA1, +) +from .hybrid import ( + DescrptHybrid, +) +from .se_e2_a import ( + DescrptSeA, +) +from .se_e2_r import ( + DescrptSeR, +) + +__all__ = [ + "DescrptSeA", + "DescrptSeR", + "DescrptDPA1", + "DescrptHybrid", +] diff --git a/source/tests/array_api_strict/descriptor/base_descriptor.py b/source/tests/array_api_strict/descriptor/base_descriptor.py new file mode 100644 index 0000000000..2a31895f55 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/base_descriptor.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.make_base_descriptor import ( + make_base_descriptor, +) + +# no type annotations standard in array api +BaseDescriptor = make_base_descriptor(Any) diff --git a/source/tests/array_api_strict/descriptor/dpa1.py b/source/tests/array_api_strict/descriptor/dpa1.py index ebd688e303..d14444f269 100644 --- a/source/tests/array_api_strict/descriptor/dpa1.py +++ b/source/tests/array_api_strict/descriptor/dpa1.py @@ -27,6 +27,9 @@ from ..utils.type_embed import ( TypeEmbedNet, ) +from .base_descriptor import ( + BaseDescriptor, +) class GatedAttentionLayer(GatedAttentionLayerDP): @@ -72,6 +75,8 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseDescriptor.register("dpa1") +@BaseDescriptor.register("se_atten") class DescrptDPA1(DescrptDPA1DP): def __setattr__(self, name: str, value: Any) -> None: if name == "se_atten": diff --git a/source/tests/array_api_strict/descriptor/hybrid.py b/source/tests/array_api_strict/descriptor/hybrid.py new file mode 100644 index 0000000000..aaaa24ed6b --- /dev/null +++ b/source/tests/array_api_strict/descriptor/hybrid.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP + +from ..common import ( + to_array_api_strict_array, +) +from .base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("hybrid") +class DescrptHybrid(DescrptHybridDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"nlist_cut_idx"}: + value = [to_array_api_strict_array(vv) for vv in value] + elif name in {"descrpt_list"}: + value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value] + + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/se_e2_a.py b/source/tests/array_api_strict/descriptor/se_e2_a.py index 654b9f8925..17da2aafbf 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_a.py +++ b/source/tests/array_api_strict/descriptor/se_e2_a.py @@ -14,8 +14,13 @@ from ..utils.network import ( NetworkCollection, ) +from .base_descriptor import ( + BaseDescriptor, +) +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: diff --git a/source/tests/array_api_strict/descriptor/se_e2_r.py b/source/tests/array_api_strict/descriptor/se_e2_r.py index 839e536cea..b499f4c4c9 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_r.py +++ b/source/tests/array_api_strict/descriptor/se_e2_r.py @@ -14,8 +14,13 @@ from ..utils.network import ( NetworkCollection, ) +from .base_descriptor import ( + BaseDescriptor, +) +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") class DescrptSeR(DescrptSeRDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: diff --git a/source/tests/consistent/descriptor/test_hybrid.py b/source/tests/consistent/descriptor/test_hybrid.py index cd52eea5be..c43652b498 100644 --- a/source/tests/consistent/descriptor/test_hybrid.py +++ b/source/tests/consistent/descriptor/test_hybrid.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -28,6 +30,16 @@ from deepmd.tf.descriptor.hybrid import DescrptHybrid as DescrptHybridTF else: DescrptHybridTF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.hybrid import DescrptHybrid as DescrptHybridJAX +else: + DescrptHybridJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.hybrid import ( + DescrptHybrid as DescrptHybridStrict, + ) +else: + DescrptHybridStrict = None from deepmd.utils.argcheck import ( descrpt_hybrid_args, ) @@ -68,8 +80,13 @@ def data(self) -> dict: tf_class = DescrptHybridTF dp_class = DescrptHybridDP pt_class = DescrptHybridPT + jax_class = DescrptHybridJAX + array_api_strict_class = DescrptHybridStrict args = descrpt_hybrid_args() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + def setUp(self): CommonTest.setUp(self) @@ -132,5 +149,23 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)