Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Nov 8, 2024
1 parent 554be6f commit c80942f
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 5 deletions.
5 changes: 4 additions & 1 deletion source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
5 changes: 4 additions & 1 deletion source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -227,7 +230,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
5 changes: 4 additions & 1 deletion source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnerFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -241,7 +244,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
(numb_aparam, use_aparam_as_mask),
atom_ener,
) = self.param
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
5 changes: 4 additions & 1 deletion source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -175,7 +178,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down
5 changes: 4 additions & 1 deletion source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.property_fitting import (
PropertyFittingNet as PropertyFittingDP,
)
Expand Down Expand Up @@ -236,7 +239,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
task_dim,
intensive,
) = self.param
return np.asarray(
return to_numpy_array(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
Expand Down

0 comments on commit c80942f

Please sign in to comment.