Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): energy, dos, dipole, polar, property atomic model & model #4384

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions deepmd/dpmodel/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import array_api_compat
import numpy as np

from deepmd.dpmodel.fitting.polarizability_fitting import (
Expand Down Expand Up @@ -34,29 +35,29 @@ def apply_out_stat(
The atom types. nf x nloc

"""
xp = array_api_compat.array_namespace(atype)
out_bias, out_std = self._fetch_out_stat(self.bias_keys)

if self.fitting_net.shift_diag:
if self.fitting.shift_diag:
nframes, nloc = atype.shape
dtype = out_bias[self.bias_keys[0]].dtype
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = np.zeros(ntypes, dtype=dtype)
temp = np.mean(
np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
temp = xp.mean(
xp.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
axis=1,
)
modified_bias = temp[atype]

# (nframes, nloc, 1)
modified_bias = (
modified_bias[..., np.newaxis] * (self.fitting_net.scale[atype])
modified_bias[..., xp.newaxis] * (self.fitting.scale[atype])
)

eye = np.eye(3, dtype=dtype)
eye = np.tile(eye, (nframes, nloc, 1, 1))
eye = xp.eye(3, dtype=dtype)
eye = xp.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
modified_bias = modified_bias[..., np.newaxis] * eye
modified_bias = modified_bias[..., xp.newaxis] * eye

# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + modified_bias
Expand Down
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dipole_atomic_model import (
DPDipoleAtomicModel as DPAtomicModelDipoleDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelDipole(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDipoleDP)):
pass
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dos_atomic_model import (
DPDOSAtomicModel as DPAtomicModelDOSDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelDOS(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDOSDP)):
pass
njzjz marked this conversation as resolved.
Show resolved Hide resolved
78 changes: 50 additions & 28 deletions deepmd/jax/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,53 @@
)


@flax_module
class DPAtomicModel(DPAtomicModelDP):
base_descriptor_cls = BaseDescriptor
"""The base descriptor class."""
base_fitting_cls = BaseFitting
"""The base fitting class."""

def __setattr__(self, name: str, value: Any) -> None:
value = base_atomic_model_set_attr(name, value)
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
) -> dict[str, jnp.ndarray]:
return super().forward_common_atomic(
extended_coord,
extended_atype,
jax.lax.stop_gradient(nlist),
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
def make_jax_dp_atomic_model_from_dpmodel(
dpmodel_atomic_model: type[DPAtomicModelDP],
) -> type[DPAtomicModelDP]:
"""Make a JAX backend DP atomic model from a DPModel backend DP atomic model.

Parameters
----------
dpmodel_atomic_model : type[DPAtomicModelDP]
The DPModel backend DP atomic model.

Returns
-------
type[DPAtomicModel]
The JAX backend DP atomic model.
"""

@flax_module
class jax_atomic_model(dpmodel_atomic_model):
base_descriptor_cls = BaseDescriptor
"""The base descriptor class."""
base_fitting_cls = BaseFitting
"""The base fitting class."""

def __setattr__(self, name: str, value: Any) -> None:
value = base_atomic_model_set_attr(name, value)
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
) -> dict[str, jnp.ndarray]:
return super().forward_common_atomic(
extended_coord,
extended_atype,
jax.lax.stop_gradient(nlist),
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

return jax_atomic_model


class DPAtomicModel(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDP)):
pass
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.energy_atomic_model import (
DPEnergyAtomicModel as DPAtomicModelEnergyDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelEnergy(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelEnergyDP)):
pass
njzjz marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.polar_atomic_model import (
DPPolarAtomicModel as DPAtomicModelPolarDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelPolar(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPolarDP)):
pass
13 changes: 13 additions & 0 deletions deepmd/jax/atomic_model/property_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.property_atomic_model import (
DPPropertyAtomicModel as DPAtomicModelPropertyDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelProperty(
make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPropertyDP)
):
pass
16 changes: 16 additions & 0 deletions deepmd/jax/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dipole_model import (
DipoleModel,
)
from .dos_model import (
DOSModel,
)
from .dp_zbl_model import (
DPZBLLinearEnergyAtomicModel,
)
from .ener_model import (
EnergyModel,
)
from .polar_model import (
PolarModel,
)
from .property_model import (
PropertyModel,
)

__all__ = [
"EnergyModel",
"DPZBLLinearEnergyAtomicModel",
"DOSModel",
"DipoleModel",
"PolarModel",
"PropertyModel",
]
17 changes: 17 additions & 0 deletions deepmd/jax/model/dipole_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.model.dipole_model import DipoleModel as DipoleModelDP
from deepmd.jax.atomic_model.dipole_atomic_model import (
DPAtomicModelDipole,
)
from deepmd.jax.model.base_model import (
BaseModel,
)
from deepmd.jax.model.dp_model import (
make_jax_dp_model_from_dpmodel,
)


@BaseModel.register("dipole")
class DipoleModel(make_jax_dp_model_from_dpmodel(DipoleModelDP, DPAtomicModelDipole)):
pass
16 changes: 16 additions & 0 deletions deepmd/jax/model/dos_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.model.dos_model import DOSModel as DOSModelDP
from deepmd.jax.atomic_model.dos_atomic_model import (
DPAtomicModelDOS,
)
from deepmd.jax.model.base_model import (
BaseModel,
)
from deepmd.jax.model.dp_model import (
make_jax_dp_model_from_dpmodel,
)


@BaseModel.register("dos")
class DOSModel(make_jax_dp_model_from_dpmodel(DOSModelDP, DPAtomicModelDOS)):
pass
86 changes: 86 additions & 0 deletions deepmd/jax/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Optional,
)

from deepmd.dpmodel.model import (
DPModelCommon,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.env import (
jax,
jnp,
)
from deepmd.jax.model.base_model import (
forward_common_atomic,
)


def make_jax_dp_model_from_dpmodel(
dpmodel_model: type[DPModelCommon], jax_atomicmodel: type[DPAtomicModel]
) -> type[DPModelCommon]:
"""Make a JAX backend DP model from a DPModel backend DP model.

Parameters
----------
dpmodel_model : type[DPModelCommon]
The DPModel backend DP model.
jax_atomicmodel : type[DPAtomicModel]
The JAX backend DP atomic model.

Returns
-------
type[DPModelCommon]
The JAX backend DP model.
"""

@flax_module
class jax_model(dpmodel_model):
def __setattr__(self, name: str, value: Any) -> None:
if name == "atomic_model":
value = jax_atomicmodel.deserialize(value.serialize())
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
return forward_common_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

def format_nlist(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
extra_nlist_sort: bool = False,
):
return dpmodel_model.format_nlist(
self,
jax.lax.stop_gradient(extended_coord),
extended_atype,
nlist,
extra_nlist_sort=extra_nlist_sort,
)

return jax_model
Loading