From c5dab77e53f3fe6b0920b2a9e40124dd529d053a Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:19:21 +0800 Subject: [PATCH] chore: refactor atomic model fitting assertion --- deepmd/dpmodel/atomic_model/dipole_atomic_model.py | 3 ++- deepmd/dpmodel/atomic_model/dos_atomic_model.py | 3 ++- deepmd/dpmodel/atomic_model/energy_atomic_model.py | 7 +++++-- deepmd/dpmodel/atomic_model/polar_atomic_model.py | 3 ++- deepmd/dpmodel/atomic_model/property_atomic_model.py | 3 ++- deepmd/pt/model/atomic_model/dipole_atomic_model.py | 3 ++- deepmd/pt/model/atomic_model/dos_atomic_model.py | 3 ++- deepmd/pt/model/atomic_model/energy_atomic_model.py | 7 +++++-- deepmd/pt/model/atomic_model/polar_atomic_model.py | 3 ++- deepmd/pt/model/atomic_model/property_atomic_model.py | 3 ++- 10 files changed, 26 insertions(+), 12 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/dipole_atomic_model.py b/deepmd/dpmodel/atomic_model/dipole_atomic_model.py index 6b7b9f470b..263dba435e 100644 --- a/deepmd/dpmodel/atomic_model/dipole_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dipole_atomic_model.py @@ -12,7 +12,8 @@ class DPDipoleAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, DipoleFitting) + if not isinstance(fitting, DipoleFitting): + raise TypeError("fitting must be an instance of DipoleFitting for DPDipoleAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( diff --git a/deepmd/dpmodel/atomic_model/dos_atomic_model.py b/deepmd/dpmodel/atomic_model/dos_atomic_model.py index fc584bcb56..8cb551313b 100644 --- a/deepmd/dpmodel/atomic_model/dos_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dos_atomic_model.py @@ -10,5 +10,6 @@ class DPDOSAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, DOSFittingNet) + if not isinstance(fitting, DOSFittingNet): + raise TypeError("fitting must be an instance of DOSFittingNet for DPDOSAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/dpmodel/atomic_model/energy_atomic_model.py b/deepmd/dpmodel/atomic_model/energy_atomic_model.py index ad00e1c2cb..a274b09ed4 100644 --- a/deepmd/dpmodel/atomic_model/energy_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/energy_atomic_model.py @@ -11,7 +11,10 @@ class DPEnergyAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, EnergyFittingNet) or isinstance( + if not (isinstance(fitting, EnergyFittingNet) or isinstance( fitting, InvarFitting - ) + )): + raise TypeError( + "fitting must be an instance of EnergyFittingNet or InvarFitting for DPEnergyAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/dpmodel/atomic_model/polar_atomic_model.py b/deepmd/dpmodel/atomic_model/polar_atomic_model.py index a049491685..cea5ad2c54 100644 --- a/deepmd/dpmodel/atomic_model/polar_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/polar_atomic_model.py @@ -13,7 +13,8 @@ class DPPolarAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, PolarFitting) + if not isinstance(fitting, PolarFitting): + raise TypeError("fitting must be an instance of PolarFitting for DPPolarAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( diff --git a/deepmd/dpmodel/atomic_model/property_atomic_model.py b/deepmd/dpmodel/atomic_model/property_atomic_model.py index ecc450bcd2..0cfd24ed91 100644 --- a/deepmd/dpmodel/atomic_model/property_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/property_atomic_model.py @@ -10,5 +10,6 @@ class DPPropertyAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, PropertyFittingNet) + if not isinstance(fitting, PropertyFittingNet): + raise TypeError("fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/pt/model/atomic_model/dipole_atomic_model.py b/deepmd/pt/model/atomic_model/dipole_atomic_model.py index aa28294cc5..e980bd9c28 100644 --- a/deepmd/pt/model/atomic_model/dipole_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dipole_atomic_model.py @@ -13,7 +13,8 @@ class DPDipoleAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, DipoleFittingNet) + if not isinstance(fitting, DipoleFittingNet): + raise TypeError("fitting must be an instance of DipoleFittingNet for DPDipoleAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( diff --git a/deepmd/pt/model/atomic_model/dos_atomic_model.py b/deepmd/pt/model/atomic_model/dos_atomic_model.py index 5e399f2aff..8ec292f3ec 100644 --- a/deepmd/pt/model/atomic_model/dos_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dos_atomic_model.py @@ -10,5 +10,6 @@ class DPDOSAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, DOSFittingNet) + if not isinstance(fitting, DOSFittingNet): + raise TypeError("fitting must be an instance of DOSFittingNet for DPDOSAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/pt/model/atomic_model/energy_atomic_model.py b/deepmd/pt/model/atomic_model/energy_atomic_model.py index 7cedaa1ab3..6d894b4aab 100644 --- a/deepmd/pt/model/atomic_model/energy_atomic_model.py +++ b/deepmd/pt/model/atomic_model/energy_atomic_model.py @@ -12,9 +12,12 @@ class DPEnergyAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert ( + if not ( isinstance(fitting, EnergyFittingNet) or isinstance(fitting, EnergyFittingNetDirect) or isinstance(fitting, InvarFitting) - ) + ): + raise TypeError( + "fitting must be an instance of EnergyFittingNet, EnergyFittingNetDirect or InvarFitting for DPEnergyAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/pt/model/atomic_model/polar_atomic_model.py b/deepmd/pt/model/atomic_model/polar_atomic_model.py index 39cda2650d..e5854baa3f 100644 --- a/deepmd/pt/model/atomic_model/polar_atomic_model.py +++ b/deepmd/pt/model/atomic_model/polar_atomic_model.py @@ -13,7 +13,8 @@ class DPPolarAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, PolarFittingNet) + if not isinstance(fitting, PolarFittingNet): + raise TypeError("fitting must be an instance of PolarFittingNet for DPPolarAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( diff --git a/deepmd/pt/model/atomic_model/property_atomic_model.py b/deepmd/pt/model/atomic_model/property_atomic_model.py index 2fac90100f..fb1cc498f8 100644 --- a/deepmd/pt/model/atomic_model/property_atomic_model.py +++ b/deepmd/pt/model/atomic_model/property_atomic_model.py @@ -13,7 +13,8 @@ class DPPropertyAtomicModel(DPAtomicModel): def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, PropertyFittingNet) + if not isinstance(fitting, PropertyFittingNet): + raise TypeError("fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel") super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat(