diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py
index 101badaab5..f5462ec4d0 100644
--- a/deepmd/dpmodel/common.py
+++ b/deepmd/dpmodel/common.py
@@ -24,7 +24,10 @@
     "int32": np.int32,
     "int64": np.int64,
     "default": GLOBAL_NP_FLOAT_PRECISION,
-    # NumPy doesn't have bfloat16 (and does't plan to add). Use float32 as a substitute.
+    # NumPy doesn't have bfloat16 (and doesn't plan to add it).
+    # ml_dtypes is a solution, but it does not seem to support np.save/np.load.
+    # HDF5 does not support bfloat16 either (see https://forum.hdfgroup.org/t/11975).
+    # Use float32 as a substitute.
     "bfloat16": np.float32,
 }
 assert VALID_PRECISION.issubset(PRECISION_DICT.keys())
diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py
index 3337036ca9..68efc3cc66 100644
--- a/deepmd/pt/utils/utils.py
+++ b/deepmd/pt/utils/utils.py
@@ -10,6 +10,7 @@
 import torch.nn.functional as F
 
 from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
+from deepmd.dpmodel.common import RESERVED_PRECISON_DICT as NP_RESERVED_PRECISON_DICT
 
 from .env import (
     DEVICE,
@@ -103,7 +104,9 @@ def to_torch_tensor(
         return None
     assert xx is not None
     # Create a reverse mapping of NP_PRECISION_DICT
-    reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()}
+    # Unsafe considering bfloat16:
+    # reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()}
+    reverse_precision_dict = NP_RESERVED_PRECISON_DICT
     # Use the reverse mapping to find keys with the desired value
     prec = reverse_precision_dict.get(xx.dtype.type, None)
     prec = PT_PRECISION_DICT.get(prec, None)
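For context, here is a minimal standalone sketch (not the deepmd-kit implementation; the small dictionaries below are illustrative) of why the naive dict inversion replaced in `to_torch_tensor` is unsafe once `"bfloat16"` is aliased to `np.float32` in `PRECISION_DICT`, and why a reserved reverse mapping like `RESERVED_PRECISON_DICT` resolves the ambiguity:

```python
import numpy as np

# Illustrative subset of a precision table where bfloat16 is substituted
# with float32, mirroring the comment added in deepmd/dpmodel/common.py.
PRECISION_DICT = {
    "float32": np.float32,
    "float64": np.float64,
    "bfloat16": np.float32,  # substitute, since NumPy has no bfloat16
}

# Naive inversion: two keys share the value np.float32, so the inverted
# dict keeps only the last one seen ("bfloat16"). A genuine float32 array
# would then be mapped to the wrong precision name.
reverse_unsafe = {v: k for k, v in PRECISION_DICT.items()}
print(reverse_unsafe[np.float32])  # -> "bfloat16"

# A reserved reverse mapping pins each NumPy type to one canonical name,
# which is the role the imported NP_RESERVED_PRECISON_DICT plays in the patch.
RESERVED_REVERSE = {
    np.float32: "float32",
    np.float64: "float64",
}
print(RESERVED_REVERSE[np.float32])  # -> "float32"
```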