Skip to content

Commit

Permalink
feat(jax): force
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Oct 25, 2024
1 parent 02580c2 commit 0517b59
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 12 deletions.
30 changes: 25 additions & 5 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,22 +222,42 @@ def call_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.atomic_model.forward_common_atomic(
model_predict = self.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

def forward_common_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
):
atomic_ret = self.atomic_model.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
model_predict = fit_output_to_model_output(
return fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
extended_coord,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

forward_lower = call_lower

Expand Down
42 changes: 40 additions & 2 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_reduce_name,
)
Expand Down Expand Up @@ -47,6 +48,15 @@ def fit_output_to_model_output(
return model_ret


def get_leading_dims(
    vv: np.ndarray,
    vdef: OutputVariableDef,
):
    """Get the dimensions of nf x nloc.

    Strips the trailing per-variable dimensions declared by ``vdef.shape``
    from ``vv.shape`` and returns the remaining leading axes as a list.
    """
    full_shape = vv.shape
    n_trailing = len(vdef.shape)
    return list(full_shape[: len(full_shape) - n_trailing])


def communicate_extended_output(
model_ret: dict[str, np.ndarray],
model_output_def: ModelOutputDef,
Expand All @@ -57,6 +67,7 @@ def communicate_extended_output(
local and ghost (extended) atoms to local atoms.
"""
xp = array_api_compat.get_namespace(mapping)
new_ret = {}
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
Expand All @@ -67,8 +78,35 @@ def communicate_extended_output(
new_ret[kk_redu] = model_ret[kk_redu]
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name holders
new_ret[kk_derv_r] = None
if model_ret[kk_derv_r] is not None:
mldims = list(mapping.shape)
vldims = get_leading_dims(vv, vdef)
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(
vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device
)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.env import (
jnp,
)

f_idx = xp.arange(force.size, dtype=xp.int64).reshape(
force.shape
)
new_idx = jnp.take_along_axis(f_idx, mapping, axis=1).ravel()
f_shape = force.shape
force = force.ravel()
force = force.at[new_idx].add(model_ret[kk_derv_r].ravel())
force = force.reshape(f_shape)
else:
raise NotImplementedError("Only JAX arrays are supported.")
new_ret[kk_derv_r] = force
else:
# name holders
new_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def _make_env_mat(
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
# the grad of JAX vector_norm is NaN at x=0
diff = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
# for index 0 nloc atom
length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
Expand Down
1 change: 1 addition & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)

jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)

__all__ = [
"jax",
Expand Down
82 changes: 82 additions & 0 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,88 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.model.base_model import (
make_base_model,
)
from deepmd.dpmodel.output_def import (
get_deriv_name,
get_reduce_name,
)
from deepmd.jax.env import (
jax,
jnp,
)

BaseModel = make_base_model()


def forward_common_atomic(
    self,
    extended_coord: jnp.ndarray,
    extended_atype: jnp.ndarray,
    nlist: jnp.ndarray,
    mapping: Optional[jnp.ndarray] = None,
    fparam: Optional[jnp.ndarray] = None,
    aparam: Optional[jnp.ndarray] = None,
    do_atomic_virial: bool = False,
):
    """Run the atomic model and build model-level outputs with JAX autodiff.

    Parameters
    ----------
    self
        Model object; must provide ``atomic_model`` and ``atomic_output_def()``.
    extended_coord
        Extended (local + ghost) atom coordinates; last axis has size 3.
    extended_atype
        Extended atom types.
    nlist
        Neighbor list.
    mapping
        Optional mapping from extended to local atom indices.
    fparam
        Optional frame parameters.
    aparam
        Optional atomic parameters.
    do_atomic_virial
        Accepted for API compatibility; the virial is not computed here.

    Returns
    -------
    dict
        The atomic outputs plus, for each reducible variable, its reduced
        output and — when r-differentiable — the force (negative gradient
        of the reduced output w.r.t. the extended coordinates).
    """
    atomic_ret = self.atomic_model.forward_common_atomic(
        extended_coord,
        extended_atype,
        nlist,
        mapping=mapping,
        fparam=fparam,
        aparam=aparam,
    )
    atomic_output_def = self.atomic_output_def()
    model_predict = {}
    for kk, vv in atomic_ret.items():
        model_predict[kk] = vv
        vdef = atomic_output_def[kk]
        shap = vdef.shape
        # the atom axis sits immediately before the per-variable shape
        atom_axis = -(len(shap) + 1)
        if vdef.reducible:
            kk_redu = get_reduce_name(kk)
            model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
            kk_derv_r, kk_derv_c = get_deriv_name(kk)
            # BUGFIX: the coordinate derivative (force) must be gated on
            # r_differentiable, not c_differentiable, matching the dpmodel
            # reference implementation (fit_output_to_model_output).
            if vdef.r_differentiable:
                # number of scalar components in this output variable
                size = 1
                for ii in vdef.shape:
                    size *= ii

                split_ff = []
                for ss in range(size):

                    def eval_output(
                        cc_ext, extended_atype, nlist, mapping, fparam, aparam
                    ):
                        # per-frame evaluation of one reduced output component;
                        # a leading frame axis is re-added for the atomic model
                        atomic_ret = self.atomic_model.forward_common_atomic(
                            cc_ext[None, ...],
                            extended_atype[None, ...],
                            nlist[None, ...],
                            mapping=mapping[None, ...] if mapping is not None else None,
                            fparam=fparam[None, ...] if fparam is not None else None,
                            aparam=aparam[None, ...] if aparam is not None else None,
                        )
                        return jnp.sum(atomic_ret[kk][0], axis=atom_axis)[ss]

                    # force = -d(reduced output)/d(coord), vmapped over frames
                    ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
                        extended_coord,
                        extended_atype,
                        nlist,
                        mapping,
                        fparam,
                        aparam,
                    )
                    ffi = ffi[..., None, :]
                    split_ff.append(ffi)
                out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
                ff = jnp.concatenate(split_ff, axis=-2).reshape(*out_lead_shape, 3)

                model_predict[kk_derv_r] = ff
            if vdef.c_differentiable:
                assert vdef.r_differentiable
                # virial is not implemented in the JAX backend yet; keep the
                # key as a placeholder for API parity
                model_predict[kk_derv_c] = None
    return model_predict
26 changes: 26 additions & 0 deletions deepmd/jax/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Optional,
)

from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
Expand All @@ -10,8 +11,12 @@
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.model.base_model import (
BaseModel,
forward_common_atomic,
)


Expand All @@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
if name == "atomic_model":
value = DPAtomicModel.deserialize(value.serialize())
return super().__setattr__(name, value)

def forward_common_atomic(
    self,
    extended_coord: jnp.ndarray,
    extended_atype: jnp.ndarray,
    nlist: jnp.ndarray,
    mapping: Optional[jnp.ndarray] = None,
    fparam: Optional[jnp.ndarray] = None,
    aparam: Optional[jnp.ndarray] = None,
    do_atomic_virial: bool = False,
):
    """Delegate to the shared module-level JAX implementation.

    All arguments are forwarded unchanged; see the module-level
    ``forward_common_atomic`` imported from ``base_model`` for details.
    """
    optional_kwargs = {
        "mapping": mapping,
        "fparam": fparam,
        "aparam": aparam,
        "do_atomic_virial": do_atomic_virial,
    }
    return forward_common_atomic(
        self,
        extended_coord,
        extended_atype,
        nlist,
        **optional_kwargs,
    )
4 changes: 4 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
"INSTALLED_ARRAY_API_STRICT",
]

SKIP_FLAG = object()


class CommonTest(ABC):
data: ClassVar[dict]
Expand Down Expand Up @@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
data2 = dp_obj.serialize()
np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
continue
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
return [ret["energy"], ret["atom_ener"]], {
return [ret["energy"], ret["atom_ener"], ret["force"]], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,
Expand Down
32 changes: 28 additions & 4 deletions source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
SKIP_FLAG,
CommonTest,
parameterized,
)
Expand Down Expand Up @@ -94,6 +95,21 @@ def data(self) -> dict:
jax_class = EnergyModelJAX
args = model_args()

def get_reference_backend(self):
    """Pick the reference backend for the consistency test.

    Forces must be reproducible by the reference, so backends are tried
    in order of preference: PT, TF, JAX, then DP.
    """
    # getattr keeps the skip_* properties lazily evaluated in order,
    # exactly like the original short-circuiting if-chain
    preference = (
        ("skip_pt", "PT"),
        ("skip_tf", "TF"),
        ("skip_jax", "JAX"),
        ("skip_dp", "DP"),
    )
    for skip_attr, backend_attr in preference:
        if not getattr(self, skip_attr):
            return getattr(self.RefBackend, backend_attr)
    raise ValueError("No available reference")

@property
def skip_tf(self):
return (
Expand Down Expand Up @@ -195,11 +211,19 @@ def eval_jax(self, jax_obj: Any) -> Any:
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
    """Normalize a backend's raw output into a comparable triple.

    Returns flattened ``(energy, atom_energy, force)`` arrays; shapes
    differ between backends, so everything is raveled. The DP backend
    carries ``SKIP_FLAG`` in the force slot since it does not return a
    force here.
    """
    if backend is self.RefBackend.DP:
        return (ret["energy_redu"].ravel(), ret["energy"].ravel(), SKIP_FLAG)
    if backend is self.RefBackend.PT:
        keys = ("energy", "atom_energy", "force")
        return tuple(ret[key].ravel() for key in keys)
    if backend is self.RefBackend.TF:
        # TF results arrive as a positional list, not a dict
        return tuple(ret[idx].ravel() for idx in range(3))
    if backend is self.RefBackend.JAX:
        keys = ("energy_redu", "energy", "energy_derv_r")
        return tuple(ret[key].ravel() for key in keys)
    raise ValueError(f"Unknown backend: {backend}")

0 comments on commit 0517b59

Please sign in to comment.