Skip to content

Commit

Permalink
feat(jax/array-api): DPA-2 (#4294)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced new classes for enhanced descriptor functionality,
including `DescrptDPA2`, `DescrptBlockRepformers`, and
`DescrptBlockSeTTebd`.
- Added serialization and deserialization methods for better state
management of descriptor objects.

- **Improvements**
- Enhanced compatibility with various array backends through the
integration of `array_api_compat`.
- Refactored existing methods to utilize new array API functions for
improved performance.
- Updated documentation to reflect JAX as a supported backend alongside
PyTorch.

- **Bug Fixes**
- Updated handling of attributes in several classes to ensure correct
deserialization and type safety.

- **Tests**
- Enhanced testing capabilities for JAX and Array API Strict backend
integration, including conditional imports and new evaluation methods.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
  • Loading branch information
4 people authored Nov 2, 2024
1 parent 6bc730f commit 25bb821
Show file tree
Hide file tree
Showing 9 changed files with 616 additions and 118 deletions.
37 changes: 25 additions & 12 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EnvMat,
NetworkCollection,
Expand Down Expand Up @@ -787,9 +794,10 @@ def call(
The smooth switch function. shape: nf x nloc x nnei
"""
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3
# nlists
nlist_dict = build_multiple_neighbor_list(
coord_ext,
Expand All @@ -798,7 +806,10 @@ def call(
self.nsel_list,
)
# repinit
g1_ext = self.type_embedding.call()[atype_ext]
g1_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
(nframes, nall, self.tebd_dim),
)
g1_inp = g1_ext[:, :nloc, :]
g1, _, _, _, _ = self.repinit(
nlist_dict[
Expand All @@ -823,16 +834,18 @@ def call(
g1_ext,
mapping,
)
g1 = np.concatenate([g1, g1_three_body], axis=-1)
g1 = xp.concat([g1, g1_three_body], axis=-1)
# linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
assert self.tebd_transform is not None
g1 = g1 + self.tebd_transform(g1_inp)
# mapping g1
assert mapping is not None
mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1]))
g1_ext = np.take_along_axis(g1, mapping_ext, axis=1)
mapping_ext = xp.tile(
xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1])
)
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -846,7 +859,7 @@ def call(
mapping,
)
if self.concat_output_tebd:
g1 = np.concatenate([g1, g1_inp], axis=-1)
g1 = xp.concat([g1, g1_inp], axis=-1)
return g1, rot_mat, g2, h2, sw

def serialize(self) -> dict:
Expand Down Expand Up @@ -883,8 +896,8 @@ def serialize(self) -> dict:
"embeddings": repinit.embeddings.serialize(),
"env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(),
"@variables": {
"davg": repinit["davg"],
"dstd": repinit["dstd"],
"davg": to_numpy_array(repinit["davg"]),
"dstd": to_numpy_array(repinit["dstd"]),
},
}
if repinit.tebd_input_mode in ["strip"]:
Expand All @@ -896,8 +909,8 @@ def serialize(self) -> dict:
"repformer_layers": [layer.serialize() for layer in repformers.layers],
"env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(),
"@variables": {
"davg": repformers["davg"],
"dstd": repformers["dstd"],
"davg": to_numpy_array(repformers["davg"]),
"dstd": to_numpy_array(repformers["dstd"]),
},
}
data.update(
Expand All @@ -913,8 +926,8 @@ def serialize(self) -> dict:
repinit_three_body.rcut, repinit_three_body.rcut_smth
).serialize(),
"@variables": {
"davg": repinit_three_body["davg"],
"dstd": repinit_three_body["dstd"],
"davg": to_numpy_array(repinit_three_body["davg"]),
"dstd": to_numpy_array(repinit_three_body["dstd"]),
},
}
if repinit_three_body.tebd_input_mode in ["strip"]:
Expand Down
Loading

0 comments on commit 25bb821

Please sign in to comment.