Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Nov 8, 2024
1 parent fae3404 commit 554be6f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
6 changes: 5 additions & 1 deletion deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Optional,
)

import array_api_compat
import ml_dtypes
import numpy as np

Expand Down Expand Up @@ -109,7 +110,10 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
# asarray is not within Array API standard, so may fail
return np.asarray(x)
except (ValueError, AttributeError):
return np.from_dlpack(x, copy=True)
xp = array_api_compat.array_namespace(x)
# to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack.
x = xp.asarray(x, copy=True)
return np.from_dlpack(x)


__all__ = [
Expand Down
11 changes: 7 additions & 4 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand Down Expand Up @@ -548,8 +551,8 @@ def serialize(self) -> dict:
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": np.array(obj["davg"]),
"dstd": np.array(obj["dstd"]),
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
## to be updated when the options are supported.
"trainable": self.trainable,
Expand Down Expand Up @@ -1022,8 +1025,8 @@ def serialize(self) -> dict:
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": np.array(obj["davg"]),
"dstd": np.array(obj["dstd"]),
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
}
if obj.tebd_input_mode in ["strip"]:
Expand Down

0 comments on commit 554be6f

Please sign in to comment.