-
Notifications
You must be signed in to change notification settings - Fork 525
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(jax): energy, dos, dipole, polar, property atomic model & model (#…
…4384) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced several new atomic model classes: `DPAtomicModelDipole`, `DPAtomicModelDOS`, `DPAtomicModelEnergy`, `DPAtomicModelPolar`, and `DPAtomicModelProperty`. - Added new model classes: `DipoleModel`, `DOSModel`, `PolarModel`, and `PropertyModel` for enhanced functionalities. - Implemented a new function to create JAX-compatible models from existing DP models, improving integration with JAX. - **Bug Fixes** - Enhanced test suite to support JAX backend, ensuring compatibility and flexibility in testing. - **Documentation** - Updated public API to include new models and functionalities. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
- Loading branch information
Showing
19 changed files
with
910 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.dpmodel.atomic_model.dipole_atomic_model import ( | ||
DPDipoleAtomicModel as DPAtomicModelDipoleDP, | ||
) | ||
from deepmd.jax.atomic_model.dp_atomic_model import ( | ||
make_jax_dp_atomic_model_from_dpmodel, | ||
) | ||
|
||
|
||
class DPAtomicModelDipole(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDipoleDP)): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.dpmodel.atomic_model.dos_atomic_model import ( | ||
DPDOSAtomicModel as DPAtomicModelDOSDP, | ||
) | ||
from deepmd.jax.atomic_model.dp_atomic_model import ( | ||
make_jax_dp_atomic_model_from_dpmodel, | ||
) | ||
|
||
|
||
class DPAtomicModelDOS(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDOSDP)): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.dpmodel.atomic_model.energy_atomic_model import ( | ||
DPEnergyAtomicModel as DPAtomicModelEnergyDP, | ||
) | ||
from deepmd.jax.atomic_model.dp_atomic_model import ( | ||
make_jax_dp_atomic_model_from_dpmodel, | ||
) | ||
|
||
|
||
class DPAtomicModelEnergy(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelEnergyDP)): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.dpmodel.atomic_model.polar_atomic_model import ( | ||
DPPolarAtomicModel as DPAtomicModelPolarDP, | ||
) | ||
from deepmd.jax.atomic_model.dp_atomic_model import ( | ||
make_jax_dp_atomic_model_from_dpmodel, | ||
) | ||
|
||
|
||
class DPAtomicModelPolar(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPolarDP)): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.dpmodel.atomic_model.property_atomic_model import ( | ||
DPPropertyAtomicModel as DPAtomicModelPropertyDP, | ||
) | ||
from deepmd.jax.atomic_model.dp_atomic_model import ( | ||
make_jax_dp_atomic_model_from_dpmodel, | ||
) | ||
|
||
|
||
class DPAtomicModelProperty( | ||
make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPropertyDP) | ||
): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,28 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from .dipole_model import ( | ||
DipoleModel, | ||
) | ||
from .dos_model import ( | ||
DOSModel, | ||
) | ||
from .dp_zbl_model import ( | ||
DPZBLLinearEnergyAtomicModel, | ||
) | ||
from .ener_model import ( | ||
EnergyModel, | ||
) | ||
from .polar_model import ( | ||
PolarModel, | ||
) | ||
from .property_model import ( | ||
PropertyModel, | ||
) | ||
|
||
__all__ = [ | ||
"EnergyModel", | ||
"DPZBLLinearEnergyAtomicModel", | ||
"DOSModel", | ||
"DipoleModel", | ||
"PolarModel", | ||
"PropertyModel", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
from deepmd.dpmodel.model.dipole_model import DipoleModel as DipoleModelDP | ||
from deepmd.jax.atomic_model.dipole_atomic_model import ( | ||
DPAtomicModelDipole, | ||
) | ||
from deepmd.jax.model.base_model import ( | ||
BaseModel, | ||
) | ||
from deepmd.jax.model.dp_model import ( | ||
make_jax_dp_model_from_dpmodel, | ||
) | ||
|
||
|
||
@BaseModel.register("dipole") | ||
class DipoleModel(make_jax_dp_model_from_dpmodel(DipoleModelDP, DPAtomicModelDipole)): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.dpmodel.model.dos_model import DOSModel as DOSModelDP | ||
from deepmd.jax.atomic_model.dos_atomic_model import ( | ||
DPAtomicModelDOS, | ||
) | ||
from deepmd.jax.model.base_model import ( | ||
BaseModel, | ||
) | ||
from deepmd.jax.model.dp_model import ( | ||
make_jax_dp_model_from_dpmodel, | ||
) | ||
|
||
|
||
@BaseModel.register("dos") | ||
class DOSModel(make_jax_dp_model_from_dpmodel(DOSModelDP, DPAtomicModelDOS)): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Any, | ||
Optional, | ||
) | ||
|
||
from deepmd.dpmodel.model import ( | ||
DPModelCommon, | ||
) | ||
from deepmd.jax.atomic_model.dp_atomic_model import ( | ||
DPAtomicModel, | ||
) | ||
from deepmd.jax.common import ( | ||
flax_module, | ||
) | ||
from deepmd.jax.env import ( | ||
jax, | ||
jnp, | ||
) | ||
from deepmd.jax.model.base_model import ( | ||
forward_common_atomic, | ||
) | ||
|
||
|
||
def make_jax_dp_model_from_dpmodel( | ||
dpmodel_model: type[DPModelCommon], jax_atomicmodel: type[DPAtomicModel] | ||
) -> type[DPModelCommon]: | ||
"""Make a JAX backend DP model from a DPModel backend DP model. | ||
Parameters | ||
---------- | ||
dpmodel_model : type[DPModelCommon] | ||
The DPModel backend DP model. | ||
jax_atomicmodel : type[DPAtomicModel] | ||
The JAX backend DP atomic model. | ||
Returns | ||
------- | ||
type[DPModelCommon] | ||
The JAX backend DP model. | ||
""" | ||
|
||
@flax_module | ||
class jax_model(dpmodel_model): | ||
def __setattr__(self, name: str, value: Any) -> None: | ||
if name == "atomic_model": | ||
value = jax_atomicmodel.deserialize(value.serialize()) | ||
return super().__setattr__(name, value) | ||
|
||
def forward_common_atomic( | ||
self, | ||
extended_coord: jnp.ndarray, | ||
extended_atype: jnp.ndarray, | ||
nlist: jnp.ndarray, | ||
mapping: Optional[jnp.ndarray] = None, | ||
fparam: Optional[jnp.ndarray] = None, | ||
aparam: Optional[jnp.ndarray] = None, | ||
do_atomic_virial: bool = False, | ||
): | ||
return forward_common_atomic( | ||
self, | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping=mapping, | ||
fparam=fparam, | ||
aparam=aparam, | ||
do_atomic_virial=do_atomic_virial, | ||
) | ||
|
||
def format_nlist( | ||
self, | ||
extended_coord: jnp.ndarray, | ||
extended_atype: jnp.ndarray, | ||
nlist: jnp.ndarray, | ||
extra_nlist_sort: bool = False, | ||
): | ||
return dpmodel_model.format_nlist( | ||
self, | ||
jax.lax.stop_gradient(extended_coord), | ||
extended_atype, | ||
nlist, | ||
extra_nlist_sort=extra_nlist_sort, | ||
) | ||
|
||
return jax_model |
Oops, something went wrong.