Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Oct 25, 2024
1 parent 004b89a commit b9eefd3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 20 deletions.
32 changes: 14 additions & 18 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,16 @@ def communicate_extended_output(
)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.env import (
jnp,
from deepmd.jax.common import (
scatter_sum,
)

f_idx = xp.arange(force.size, dtype=xp.int64).reshape(
force.shape
force = scatter_sum(
force,
1,
mapping,
model_ret[kk_derv_r],
)
new_idx = jnp.take_along_axis(f_idx, mapping, axis=1).ravel()
f_shape = force.shape
force = force.ravel()
force = force.at[new_idx].add(model_ret[kk_derv_r].ravel())
force = force.reshape(f_shape)
else:
raise NotImplementedError("Only JAX arrays are supported.")

Check warning on line 116 in deepmd/dpmodel/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/transform_output.py#L116

Added line #L116 was not covered by tests
new_ret[kk_derv_r] = force
Expand All @@ -132,18 +130,16 @@ def communicate_extended_output(
)
# jax only
if array_api_compat.is_jax_array(virial):
from deepmd.jax.env import (
jnp,
from deepmd.jax.common import (
scatter_sum,
)

v_idx = xp.arange(virial.size, dtype=xp.int64).reshape(
virial.shape
virial = scatter_sum(
virial,
1,
mapping,
model_ret[kk_derv_c],
)
new_idx = jnp.take_along_axis(v_idx, mapping, axis=1).ravel()
v_shape = virial.shape
virial = virial.ravel()
virial = virial.at[new_idx].add(model_ret[kk_derv_c].ravel())
virial = virial.reshape(v_shape)
else:
raise NotImplementedError("Only JAX arrays are supported.")

Check warning on line 144 in deepmd/dpmodel/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/transform_output.py#L144

Added line #L144 was not covered by tests
new_ret[kk_derv_c] = virial
Expand Down
10 changes: 10 additions & 0 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,13 @@ def __dlpack__(self, *args, **kwargs):

def __dlpack_device__(self, *args, **kwargs):
return self.value.__dlpack_device__(*args, **kwargs)


def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
shape = input.shape
input = input.ravel()
input = input.at[new_idx].add(src.ravel())
return input.reshape(shape)
13 changes: 11 additions & 2 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,16 @@ def forward_common_atomic(
for ss in range(size):

def eval_output(
cc_ext, extended_atype, nlist, mapping, fparam, aparam
cc_ext,
extended_atype,
nlist,
mapping,
fparam,
aparam,
*,
_kk=kk,
_ss=ss,
_atom_axis=atom_axis,
):
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext[None, ...],
Expand All @@ -67,7 +76,7 @@ def eval_output(
fparam=fparam[None, ...] if fparam is not None else None,
aparam=aparam[None, ...] if aparam is not None else None,
)
return jnp.sum(atomic_ret[kk][0], axis=atom_axis)[ss]
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)[_ss]

ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
extended_coord,
Expand Down

0 comments on commit b9eefd3

Please sign in to comment.