diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 2925b0ac32..28e29cdcb7 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -8,12 +8,12 @@ from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.dpmodel.fitting.dos_fitting import ( + DOSFittingNet, +) from deepmd.dpmodel.fitting.ener_fitting import ( EnergyFittingNet, ) -from deepmd.dpmodel.fitting.dos_fitting import ( - DOSFittingNet -) from deepmd.dpmodel.model.base_model import ( BaseModel, ) diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 9dbc69ba91..ef1c7cf911 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -34,7 +34,9 @@ class ModelTest: """Useful utilities for model tests.""" - def build_tf_model(self, obj, natoms, coords, atype, box, suffix, ret_key:str="energy"): + def build_tf_model( + self, obj, natoms, coords, atype, box, suffix, ret_key: str = "energy" + ): t_coord = tf.placeholder( GLOBAL_TF_FLOAT_PRECISION, [None, None, None], name="i_coord" ) diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index 5239c68ed7..8f0b0309cc 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -82,7 +82,7 @@ def get_reference_backend(self): @property def skip_tf(self): - return True # need to fix tf consistency + return True # need to fix tf consistency @property def skip_jax(self) -> bool: @@ -140,13 +140,7 @@ def setUp(self) -> None: def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: return self.build_tf_model( - obj, - self.natoms, - self.coords, - self.atype, - self.box, - suffix, - ret_key = "dos" + obj, self.natoms, self.coords, self.atype, self.box, suffix, ret_key="dos" ) def eval_dp(self, dp_obj: Any) -> Any: