Skip to content

Commit

Permalink
chore: refactor atomic model fitting assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Nov 1, 2024
1 parent 8e85e59 commit c5dab77
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 12 deletions.
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

class DPDipoleAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DipoleFitting)
if not isinstance(fitting, DipoleFitting):
raise TypeError("fitting must be an instance of DipoleFitting for DPDipoleAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@

class DPDOSAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DOSFittingNet)
if not isinstance(fitting, DOSFittingNet):
raise TypeError("fitting must be an instance of DOSFittingNet for DPDOSAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)
7 changes: 5 additions & 2 deletions deepmd/dpmodel/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

class DPEnergyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, EnergyFittingNet) or isinstance(
if not (isinstance(fitting, EnergyFittingNet) or isinstance(
fitting, InvarFitting
)
)):
raise TypeError(
"fitting must be an instance of EnergyFittingNet or InvarFitting for DPEnergyAtomicModel"
)
super().__init__(descriptor, fitting, type_map, **kwargs)
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

class DPPolarAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, PolarFitting)
if not isinstance(fitting, PolarFitting):
raise TypeError("fitting must be an instance of PolarFitting for DPPolarAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/property_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@

class DPPropertyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, PropertyFittingNet)
if not isinstance(fitting, PropertyFittingNet):
raise TypeError("fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)
3 changes: 2 additions & 1 deletion deepmd/pt/model/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

class DPDipoleAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DipoleFittingNet)
if not isinstance(fitting, DipoleFittingNet):
raise TypeError("fitting must be an instance of DipoleFittingNet for DPDipoleAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
Expand Down
3 changes: 2 additions & 1 deletion deepmd/pt/model/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@

class DPDOSAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DOSFittingNet)
if not isinstance(fitting, DOSFittingNet):
raise TypeError("fitting must be an instance of DOSFittingNet for DPDOSAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)
7 changes: 5 additions & 2 deletions deepmd/pt/model/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@

class DPEnergyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert (
if not (
isinstance(fitting, EnergyFittingNet)
or isinstance(fitting, EnergyFittingNetDirect)
or isinstance(fitting, InvarFitting)
)
):
raise TypeError(
"fitting must be an instance of EnergyFittingNet, EnergyFittingNetDirect or InvarFitting for DPEnergyAtomicModel"
)
super().__init__(descriptor, fitting, type_map, **kwargs)
3 changes: 2 additions & 1 deletion deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

class DPPolarAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, PolarFittingNet)
if not isinstance(fitting, PolarFittingNet):
raise TypeError("fitting must be an instance of PolarFittingNet for DPPolarAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
Expand Down
3 changes: 2 additions & 1 deletion deepmd/pt/model/atomic_model/property_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

class DPPropertyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, PropertyFittingNet)
if not isinstance(fitting, PropertyFittingNet):
raise TypeError("fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel")
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
Expand Down

0 comments on commit c5dab77

Please sign in to comment.