From c80942f0815b57eb37222dbf1d35595e217bbe7d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 21:36:44 -0500 Subject: [PATCH] fix tests Signed-off-by: Jinzhe Zeng --- source/tests/consistent/fitting/test_dipole.py | 5 ++++- source/tests/consistent/fitting/test_dos.py | 5 ++++- source/tests/consistent/fitting/test_ener.py | 5 ++++- source/tests/consistent/fitting/test_polar.py | 5 ++++- source/tests/consistent/fitting/test_property.py | 5 ++++- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 60ee7322c1..088cb30238 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index d3de3ef151..ce7905585d 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -227,7 +230,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: numb_aparam, numb_dos, ) = self.param - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index f4e78ce966..8bed595733 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnerFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -241,7 +244,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: (numb_aparam, use_aparam_as_mask), atom_ener, ) = self.param - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index bd9d013b8d..12f13d1e08 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)), diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index a096d4dd68..b83c12f581 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -6,6 +6,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.property_fitting import ( PropertyFittingNet as PropertyFittingDP, ) @@ -236,7 +239,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: task_dim, intensive, ) = self.param - return np.asarray( + return to_numpy_array( array_api_strict_obj( array_api_strict.asarray(self.inputs), array_api_strict.asarray(self.atype.reshape(1, -1)),