Skip to content

Commit

Permalink
feat(jax): energy, dos, dipole, polar, property atomic model & model (#…
Browse files Browse the repository at this point in the history
…4384)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced several new atomic model classes: `DPAtomicModelDipole`,
`DPAtomicModelDOS`, `DPAtomicModelEnergy`, `DPAtomicModelPolar`, and
`DPAtomicModelProperty`.
- Added new model classes: `DipoleModel`, `DOSModel`, `PolarModel`, and
`PropertyModel` for enhanced functionalities.
- Implemented a new function to create JAX-compatible models from
existing DP models, improving integration with JAX.

- **Bug Fixes**
- Enhanced test suite to support JAX backend, ensuring compatibility and
flexibility in testing.

- **Documentation**
  - Updated public API to include new models and functionalities.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Nov 21, 2024
1 parent 4334377 commit e7925f3
Show file tree
Hide file tree
Showing 19 changed files with 910 additions and 95 deletions.
17 changes: 9 additions & 8 deletions deepmd/dpmodel/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import array_api_compat
import numpy as np

from deepmd.dpmodel.fitting.polarizability_fitting import (
Expand Down Expand Up @@ -34,29 +35,29 @@ def apply_out_stat(
The atom types. nf x nloc
"""
xp = array_api_compat.array_namespace(atype)
out_bias, out_std = self._fetch_out_stat(self.bias_keys)

if self.fitting_net.shift_diag:
if self.fitting.shift_diag:
nframes, nloc = atype.shape
dtype = out_bias[self.bias_keys[0]].dtype
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = np.zeros(ntypes, dtype=dtype)
temp = np.mean(
np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
temp = xp.mean(
xp.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
axis=1,
)
modified_bias = temp[atype]

# (nframes, nloc, 1)
modified_bias = (
modified_bias[..., np.newaxis] * (self.fitting_net.scale[atype])
modified_bias[..., xp.newaxis] * (self.fitting.scale[atype])
)

eye = np.eye(3, dtype=dtype)
eye = np.tile(eye, (nframes, nloc, 1, 1))
eye = xp.eye(3, dtype=dtype)
eye = xp.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
modified_bias = modified_bias[..., np.newaxis] * eye
modified_bias = modified_bias[..., xp.newaxis] * eye

# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + modified_bias
Expand Down
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dipole_atomic_model import (
DPDipoleAtomicModel as DPAtomicModelDipoleDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelDipole(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDipoleDP)):
pass
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dos_atomic_model import (
DPDOSAtomicModel as DPAtomicModelDOSDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelDOS(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDOSDP)):
pass
78 changes: 50 additions & 28 deletions deepmd/jax/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,53 @@
)


@flax_module
class DPAtomicModel(DPAtomicModelDP):
base_descriptor_cls = BaseDescriptor
"""The base descriptor class."""
base_fitting_cls = BaseFitting
"""The base fitting class."""

def __setattr__(self, name: str, value: Any) -> None:
value = base_atomic_model_set_attr(name, value)
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
) -> dict[str, jnp.ndarray]:
return super().forward_common_atomic(
extended_coord,
extended_atype,
jax.lax.stop_gradient(nlist),
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
def make_jax_dp_atomic_model_from_dpmodel(
dpmodel_atomic_model: type[DPAtomicModelDP],
) -> type[DPAtomicModelDP]:
"""Make a JAX backend DP atomic model from a DPModel backend DP atomic model.
Parameters
----------
dpmodel_atomic_model : type[DPAtomicModelDP]
The DPModel backend DP atomic model.
Returns
-------
type[DPAtomicModel]
The JAX backend DP atomic model.
"""

@flax_module
class jax_atomic_model(dpmodel_atomic_model):
base_descriptor_cls = BaseDescriptor
"""The base descriptor class."""
base_fitting_cls = BaseFitting
"""The base fitting class."""

def __setattr__(self, name: str, value: Any) -> None:
value = base_atomic_model_set_attr(name, value)
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
) -> dict[str, jnp.ndarray]:
return super().forward_common_atomic(
extended_coord,
extended_atype,
jax.lax.stop_gradient(nlist),
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

return jax_atomic_model


class DPAtomicModel(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDP)):
pass
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.energy_atomic_model import (
DPEnergyAtomicModel as DPAtomicModelEnergyDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelEnergy(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelEnergyDP)):
pass
11 changes: 11 additions & 0 deletions deepmd/jax/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.polar_atomic_model import (
DPPolarAtomicModel as DPAtomicModelPolarDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelPolar(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPolarDP)):
pass
13 changes: 13 additions & 0 deletions deepmd/jax/atomic_model/property_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.property_atomic_model import (
DPPropertyAtomicModel as DPAtomicModelPropertyDP,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
make_jax_dp_atomic_model_from_dpmodel,
)


class DPAtomicModelProperty(
make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPropertyDP)
):
pass
16 changes: 16 additions & 0 deletions deepmd/jax/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dipole_model import (
DipoleModel,
)
from .dos_model import (
DOSModel,
)
from .dp_zbl_model import (
DPZBLLinearEnergyAtomicModel,
)
from .ener_model import (
EnergyModel,
)
from .polar_model import (
PolarModel,
)
from .property_model import (
PropertyModel,
)

__all__ = [
"EnergyModel",
"DPZBLLinearEnergyAtomicModel",
"DOSModel",
"DipoleModel",
"PolarModel",
"PropertyModel",
]
17 changes: 17 additions & 0 deletions deepmd/jax/model/dipole_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.model.dipole_model import DipoleModel as DipoleModelDP
from deepmd.jax.atomic_model.dipole_atomic_model import (
DPAtomicModelDipole,
)
from deepmd.jax.model.base_model import (
BaseModel,
)
from deepmd.jax.model.dp_model import (
make_jax_dp_model_from_dpmodel,
)


@BaseModel.register("dipole")
class DipoleModel(make_jax_dp_model_from_dpmodel(DipoleModelDP, DPAtomicModelDipole)):
pass
16 changes: 16 additions & 0 deletions deepmd/jax/model/dos_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.model.dos_model import DOSModel as DOSModelDP
from deepmd.jax.atomic_model.dos_atomic_model import (
DPAtomicModelDOS,
)
from deepmd.jax.model.base_model import (
BaseModel,
)
from deepmd.jax.model.dp_model import (
make_jax_dp_model_from_dpmodel,
)


@BaseModel.register("dos")
class DOSModel(make_jax_dp_model_from_dpmodel(DOSModelDP, DPAtomicModelDOS)):
pass
86 changes: 86 additions & 0 deletions deepmd/jax/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Optional,
)

from deepmd.dpmodel.model import (
DPModelCommon,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.env import (
jax,
jnp,
)
from deepmd.jax.model.base_model import (
forward_common_atomic,
)


def make_jax_dp_model_from_dpmodel(
dpmodel_model: type[DPModelCommon], jax_atomicmodel: type[DPAtomicModel]
) -> type[DPModelCommon]:
"""Make a JAX backend DP model from a DPModel backend DP model.
Parameters
----------
dpmodel_model : type[DPModelCommon]
The DPModel backend DP model.
jax_atomicmodel : type[DPAtomicModel]
The JAX backend DP atomic model.
Returns
-------
type[DPModelCommon]
The JAX backend DP model.
"""

@flax_module
class jax_model(dpmodel_model):
def __setattr__(self, name: str, value: Any) -> None:
if name == "atomic_model":
value = jax_atomicmodel.deserialize(value.serialize())
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
return forward_common_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

def format_nlist(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
extra_nlist_sort: bool = False,
):
return dpmodel_model.format_nlist(
self,
jax.lax.stop_gradient(extended_coord),
extended_atype,
nlist,
extra_nlist_sort=extra_nlist_sort,
)

return jax_model
Loading

0 comments on commit e7925f3

Please sign in to comment.