Merge branch 'devel' into add_paddle_backend
HydrogenSulfate committed Oct 30, 2024
2 parents cbc9c65 + d165fee commit afd4746
Showing 38 changed files with 1,500 additions and 247 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).

### Highlighted features

- - **interfaced with multiple backends**, including TensorFlow, PyTorch and Paddle, the most popular deep learning frameworks, making the training process highly automatic and efficient.
+ - **interfaced with multiple backends**, including TensorFlow, PyTorch, JAX and Paddle, the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABACUS.
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
@@ -72,7 +72,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all features.

#### v3

- - Multiple backends supported. Add PyTorch and Paddle backend.
+ - Multiple backends supported. Add PyTorch, JAX and Paddle backends.
- The DPA-2 model.

## Install and use DeePMD-kit
20 changes: 14 additions & 6 deletions deepmd/backend/jax.py
@@ -33,12 +33,12 @@ class JAXBackend(Backend):
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.IO
# Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.ENTRY_POINT
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".jax"]
suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
"""The suffixes of the backend."""

    def is_available(self) -> bool:
@@ -71,7 +71,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
        type[DeepEvalBackend]
            The Deep Eval backend of the backend.
        """
-        raise NotImplementedError
+        from deepmd.jax.infer.deep_eval import (
+            DeepEval,
+        )
+
+        return DeepEval

    @property
    def neighbor_stat(self) -> type["NeighborStat"]:
@@ -82,7 +86,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
        type[NeighborStat]
            The neighbor statistics of the backend.
        """
-        raise NotImplementedError
+        from deepmd.jax.utils.neighbor_stat import (
+            NeighborStat,
+        )
+
+        return NeighborStat

    @property
    def serialize_hook(self) -> Callable[[str], dict]:
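Aside on the `Backend.Feature` flags enabled in this file: the uncommented lines combine capabilities with bitwise OR. A minimal, self-contained sketch of how such flag composition behaves, assuming `Backend.Feature` is an `enum.Flag`-style type (the `Feature` class below is illustrative, not deepmd's actual definition):

```python
from enum import Flag, auto


class Feature(Flag):
    """Illustrative stand-in for Backend.Feature."""

    IO = auto()
    ENTRY_POINT = auto()
    DEEP_EVAL = auto()
    NEIGHBOR_STAT = auto()


# Combining flags with | yields a single value carrying all capabilities.
features = Feature.IO | Feature.ENTRY_POINT | Feature.DEEP_EVAL | Feature.NEIGHBOR_STAT
assert Feature.DEEP_EVAL in features  # membership check via bitwise containment
```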
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/se_e2_a.py
@@ -555,7 +555,7 @@ def call(
            coord_ext, atype_ext, nlist, self.davg, self.dstd
        )
        nf, nloc, nnei, _ = rr.shape
-        sec = xp.asarray(self.sel_cumsum)
+        sec = self.sel_cumsum

        ng = self.neuron[-1]
        gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
160 changes: 124 additions & 36 deletions deepmd/dpmodel/model/make_model.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
+    Callable,
    Optional,
)

@@ -39,6 +40,95 @@
)


def model_call_from_call_lower(
    *,  # enforce keyword-only arguments
    call_lower: Callable[
        [
            np.ndarray,
            np.ndarray,
            np.ndarray,
            Optional[np.ndarray],
            Optional[np.ndarray],
            bool,
        ],
        dict[str, np.ndarray],
    ],
    rcut: float,
    sel: list[int],
    mixed_types: bool,
    model_output_def: ModelOutputDef,
    coord: np.ndarray,
    atype: np.ndarray,
    box: Optional[np.ndarray] = None,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    do_atomic_virial: bool = False,
):
"""Return model prediction from lower interface.
Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
Returns
-------
ret_dict
The result dict of type dict[str,np.ndarray].
The keys are defined by the `ModelOutputDef`.
"""
    nframes, nloc = atype.shape[:2]
    cc, bb, fp, ap = coord, box, fparam, aparam
    del coord, box, fparam, aparam
    if bb is not None:
        coord_normalized = normalize_coord(
            cc.reshape(nframes, nloc, 3),
            bb.reshape(nframes, 3, 3),
        )
    else:
        coord_normalized = cc.copy()
    extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
        coord_normalized, atype, bb, rcut
    )
    nlist = build_neighbor_list(
        extended_coord,
        extended_atype,
        nloc,
        rcut,
        sel,
        distinguish_types=not mixed_types,
    )
    extended_coord = extended_coord.reshape(nframes, -1, 3)
    model_predict_lower = call_lower(
        extended_coord,
        extended_atype,
        nlist,
        mapping,
        fparam=fp,
        aparam=ap,
        do_atomic_virial=do_atomic_virial,
    )
    model_predict = communicate_extended_output(
        model_predict_lower,
        model_output_def,
        mapping,
        do_atomic_virial=do_atomic_virial,
    )
    return model_predict
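The first step of the pipeline above wraps coordinates back into the periodic cell before ghost atoms are built. A minimal numpy sketch of that normalization step, illustrating the idea only, not deepmd's actual `normalize_coord`:

```python
import numpy as np


def wrap_into_cell(coord: np.ndarray, cell: np.ndarray) -> np.ndarray:
    """Wrap cartesian positions into the periodic cell (rows of `cell`)."""
    frac = coord @ np.linalg.inv(cell)  # cartesian -> fractional
    frac = frac % 1.0                   # wrap into [0, 1)
    return frac @ cell                  # fractional -> cartesian


cell = np.diag([10.0, 10.0, 10.0])
coord = np.array([[11.0, -1.0, 5.0]])  # atom outside the box
print(wrap_into_cell(coord, cell))     # -> [[1. 9. 5.]]
```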


def make_model(T_AtomicModel: type[BaseAtomicModel]):
"""Make a model as a derived class of an atomic model.
@@ -130,45 +220,23 @@ def call(
        The keys are defined by the `ModelOutputDef`.
        """
-        nframes, nloc = atype.shape[:2]
        cc, bb, fp, ap, input_prec = self.input_type_cast(
            coord, box=box, fparam=fparam, aparam=aparam
        )
        del coord, box, fparam, aparam
-        if bb is not None:
-            coord_normalized = normalize_coord(
-                cc.reshape(nframes, nloc, 3),
-                bb.reshape(nframes, 3, 3),
-            )
-        else:
-            coord_normalized = cc.copy()
-        extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
-            coord_normalized, atype, bb, self.get_rcut()
-        )
-        nlist = build_neighbor_list(
-            extended_coord,
-            extended_atype,
-            nloc,
-            self.get_rcut(),
-            self.get_sel(),
-            distinguish_types=not self.mixed_types(),
-        )
-        extended_coord = extended_coord.reshape(nframes, -1, 3)
-        model_predict_lower = self.call_lower(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping,
+        model_predict = model_call_from_call_lower(
+            call_lower=self.call_lower,
+            rcut=self.get_rcut(),
+            sel=self.get_sel(),
+            mixed_types=self.mixed_types(),
+            model_output_def=self.model_output_def(),
+            coord=cc,
+            atype=atype,
+            box=bb,
            fparam=fp,
            aparam=ap,
            do_atomic_virial=do_atomic_virial,
        )
-        model_predict = communicate_extended_output(
-            model_predict_lower,
-            self.model_output_def(),
-            mapping,
-            do_atomic_virial=do_atomic_virial,
-        )
        model_predict = self.output_type_cast(model_predict, input_prec)
        return model_predict

@@ -222,22 +290,42 @@ def call_lower(
            extended_coord, fparam=fparam, aparam=aparam
        )
        del extended_coord, fparam, aparam
-        atomic_ret = self.atomic_model.forward_common_atomic(
+        model_predict = self.forward_common_atomic(
            cc_ext,
            extended_atype,
            nlist,
            mapping=mapping,
            fparam=fp,
            aparam=ap,
+            do_atomic_virial=do_atomic_virial,
        )
-        model_predict = fit_output_to_model_output(
+        model_predict = self.output_type_cast(model_predict, input_prec)
+        return model_predict
+
+    def forward_common_atomic(
+        self,
+        extended_coord: np.ndarray,
+        extended_atype: np.ndarray,
+        nlist: np.ndarray,
+        mapping: Optional[np.ndarray] = None,
+        fparam: Optional[np.ndarray] = None,
+        aparam: Optional[np.ndarray] = None,
+        do_atomic_virial: bool = False,
+    ):
+        atomic_ret = self.atomic_model.forward_common_atomic(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+        )
+        return fit_output_to_model_output(
            atomic_ret,
            self.atomic_output_def(),
-            cc_ext,
+            extended_coord,
            do_atomic_virial=do_atomic_virial,
        )
-        model_predict = self.output_type_cast(model_predict, input_prec)
-        return model_predict

    forward_lower = call_lower

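The `forward_lower = call_lower` line above binds a second name to the same function object at class-creation time, so both names always dispatch to one implementation. A tiny self-contained sketch of the pattern:

```python
class Adder:
    def add(self, x: int, y: int) -> int:
        return x + y

    plus = add  # alias created at class definition, like forward_lower = call_lower


assert Adder().plus(1, 2) == 3  # both names refer to the same method
```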
84 changes: 78 additions & 6 deletions deepmd/dpmodel/model/transform_output.py
@@ -9,6 +9,7 @@
from deepmd.dpmodel.output_def import (
    FittingOutputDef,
    ModelOutputDef,
+    OutputVariableDef,
    get_deriv_name,
    get_reduce_name,
)
@@ -47,6 +48,28 @@ def fit_output_to_model_output(
    return model_ret


def get_leading_dims(
    vv: np.ndarray,
    vdef: OutputVariableDef,
):
    """Get the dimensions of nf x nloc.

    Parameters
    ----------
    vv : np.ndarray
        The input array from which to compute the leading dimensions.
    vdef : OutputVariableDef
        The output variable definition containing the shape to exclude from `vv`.

    Returns
    -------
    list
        A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
    """
    vshape = vv.shape
    return list(vshape[: (len(vshape) - len(vdef.shape))])
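A quick worked example of what `get_leading_dims` returns, assuming `vdef.shape` holds the per-atom shape of the output variable:

```python
import numpy as np

# An energy-like atomic output with per-atom shape [1], stored as nf x nloc x 1.
vv = np.zeros((4, 192, 1))  # nf = 4 frames, nloc = 192 local atoms
vdef_shape = [1]            # stand-in for vdef.shape

leading = list(vv.shape[: len(vv.shape) - len(vdef_shape)])
assert leading == [4, 192]  # the nf x nloc leading dimensions
```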


def communicate_extended_output(
    model_ret: dict[str, np.ndarray],
    model_output_def: ModelOutputDef,
@@ -57,6 +80,7 @@ def communicate_extended_output(
    local and ghost (extended) atoms to local atoms.
    """
+    xp = array_api_compat.get_namespace(mapping)
    new_ret = {}
    for kk in model_output_def.keys_outp():
        vv = model_ret[kk]
@@ -65,15 +89,63 @@
        if vdef.reducible:
            kk_redu = get_reduce_name(kk)
            new_ret[kk_redu] = model_ret[kk_redu]
+            kk_derv_r, kk_derv_c = get_deriv_name(kk)
+            mldims = list(mapping.shape)
+            vldims = get_leading_dims(vv, vdef)
            if vdef.r_differentiable:
-                kk_derv_r, kk_derv_c = get_deriv_name(kk)
-                # name holders
-                new_ret[kk_derv_r] = None
+                if model_ret[kk_derv_r] is not None:
+                    derv_r_ext_dims = list(vdef.shape) + [3]  # noqa:RUF005
+                    mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
+                    mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
+                    force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
+                    # jax only
+                    if array_api_compat.is_jax_array(force):
+                        from deepmd.jax.common import (
+                            scatter_sum,
+                        )
+
+                        force = scatter_sum(
+                            force,
+                            1,
+                            mapping,
+                            model_ret[kk_derv_r],
+                        )
+                    else:
+                        raise NotImplementedError("Only JAX arrays are supported.")
+                    new_ret[kk_derv_r] = force
+                else:
+                    # name holders
+                    new_ret[kk_derv_r] = None
            if vdef.c_differentiable:
                assert vdef.r_differentiable
-                kk_derv_r, kk_derv_c = get_deriv_name(kk)
-                new_ret[kk_derv_c] = None
-                new_ret[kk_derv_c + "_redu"] = None
+                if model_ret[kk_derv_c] is not None:
+                    derv_c_ext_dims = list(vdef.shape) + [9]  # noqa:RUF005
+                    mapping = xp.tile(
+                        mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
+                    )
+                    virial = xp.zeros(
+                        vldims + derv_c_ext_dims,
+                        dtype=vv.dtype,
+                    )
+                    # jax only
+                    if array_api_compat.is_jax_array(virial):
+                        from deepmd.jax.common import (
+                            scatter_sum,
+                        )
+
+                        virial = scatter_sum(
+                            virial,
+                            1,
+                            mapping,
+                            model_ret[kk_derv_c],
+                        )
+                    else:
+                        raise NotImplementedError("Only JAX arrays are supported.")
+                    new_ret[kk_derv_c] = virial
+                    new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
+                else:
+                    new_ret[kk_derv_c] = None
+                    new_ret[kk_derv_c + "_redu"] = None
                if not do_atomic_virial:
                    # pop atomic virial, because it is not correctly calculated.
                    new_ret.pop(kk_derv_c)
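For context on the `scatter_sum` calls above: forces and virials are computed on extended (local + ghost) atoms, and the ghost contributions must be accumulated back onto the local atoms they map to. A minimal JAX sketch of such a scatter-add along axis 1, illustrating the idea rather than the actual `deepmd.jax.common.scatter_sum`:

```python
import jax.numpy as jnp

nf, nall, nloc = 2, 6, 4                        # frames, extended atoms, local atoms
f_ext = jnp.ones((nf, nall, 3))                 # per-extended-atom forces
mapping = jnp.array([[0, 1, 2, 3, 0, 1]] * nf)  # extended index -> owning local atom

force = jnp.zeros((nf, nloc, 3))
frames = jnp.arange(nf)[:, None]                # broadcast frame indices
# .at[].add is JAX's functional scatter-add; duplicate indices accumulate.
force = force.at[frames, mapping].add(f_ext)
print(force[0])  # local atoms 0 and 1 received two contributions each
```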