-
Notifications
You must be signed in to change notification settings - Fork 526
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
- Loading branch information
Showing
9 changed files
with
209 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,88 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Optional, | ||
) | ||
|
||
from deepmd.dpmodel.model.base_model import ( | ||
make_base_model, | ||
) | ||
from deepmd.dpmodel.output_def import ( | ||
get_deriv_name, | ||
get_reduce_name, | ||
) | ||
from deepmd.jax.env import ( | ||
jax, | ||
jnp, | ||
) | ||
|
||
BaseModel = make_base_model() | ||
|
||
|
||
def forward_common_atomic( | ||
self, | ||
extended_coord: jnp.ndarray, | ||
extended_atype: jnp.ndarray, | ||
nlist: jnp.ndarray, | ||
mapping: Optional[jnp.ndarray] = None, | ||
fparam: Optional[jnp.ndarray] = None, | ||
aparam: Optional[jnp.ndarray] = None, | ||
do_atomic_virial: bool = False, | ||
): | ||
atomic_ret = self.atomic_model.forward_common_atomic( | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping=mapping, | ||
fparam=fparam, | ||
aparam=aparam, | ||
) | ||
atomic_output_def = self.atomic_output_def() | ||
model_predict = {} | ||
for kk, vv in atomic_ret.items(): | ||
model_predict[kk] = vv | ||
vdef = atomic_output_def[kk] | ||
shap = vdef.shape | ||
atom_axis = -(len(shap) + 1) | ||
if vdef.reducible: | ||
kk_redu = get_reduce_name(kk) | ||
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis) | ||
kk_derv_r, kk_derv_c = get_deriv_name(kk) | ||
if vdef.c_differentiable: | ||
size = 1 | ||
for ii in vdef.shape: | ||
size *= ii | ||
|
||
split_ff = [] | ||
for ss in range(size): | ||
|
||
def eval_output( | ||
cc_ext, extended_atype, nlist, mapping, fparam, aparam | ||
): | ||
atomic_ret = self.atomic_model.forward_common_atomic( | ||
cc_ext[None, ...], | ||
extended_atype[None, ...], | ||
nlist[None, ...], | ||
mapping=mapping[None, ...] if mapping is not None else None, | ||
fparam=fparam[None, ...] if fparam is not None else None, | ||
aparam=aparam[None, ...] if aparam is not None else None, | ||
) | ||
return jnp.sum(atomic_ret[kk][0], axis=atom_axis)[ss] | ||
|
||
ffi = -jax.vmap(jax.grad(eval_output, argnums=0))( | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping, | ||
fparam, | ||
aparam, | ||
) | ||
ffi = ffi[..., None, :] | ||
split_ff.append(ffi) | ||
out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape | ||
ff = jnp.concatenate(split_ff, axis=-2).reshape(*out_lead_shape, 3) | ||
|
||
model_predict[kk_derv_r] = ff | ||
if vdef.c_differentiable: | ||
assert vdef.r_differentiable | ||
model_predict[kk_derv_c] = None | ||
return model_predict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters