Skip to content

Commit aafccb2

Browse files
committed
use None to add one fake dim for single frame datasets
1 parent 291cbdb commit aafccb2

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

deepmd/utils/data.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -854,28 +854,19 @@ def _load_single_data(
854854

855855
# Branch 2: File exists, use memmap
856856
mmap_obj = self._get_memmap(path)
857+
# corner case: single frame
858+
if self._get_nframes(set_dir) == 1:
859+
mmap_obj = mmap_obj[None, ...]
857860
# Slice the single frame and make an in-memory copy for modification
858-
if mmap_obj.ndim == 0:
859-
# case: single frame data && non-atomic
860-
data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1)
861-
elif mmap_obj.ndim == 1:
862-
# case: single frame data && atomic
863-
if frame_idx != 0:
864-
raise IndexError(
865-
f"frame index {frame_idx} out of range for single-frame file: {path}"
866-
)
867-
data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1)
868-
else:
869-
# case: multi-frame data
870-
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False).reshape(1, -1)
861+
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False)
871862

872863
try:
873864
if vv["atomic"]:
874865
# Handle type_sel logic
875866
if vv["type_sel"] is not None:
876867
sel_mask = np.isin(self.atom_type, vv["type_sel"])
877868

878-
if data.shape[1] == natoms_sel * ndof:
869+
if mmap_obj.shape[1] == natoms_sel * ndof:
879870
if vv["output_natoms_for_type_sel"]:
880871
tmp = np.zeros([natoms, ndof], dtype=data.dtype)
881872
# sel_mask needs to be applied to the original atom layout
@@ -884,7 +875,7 @@ def _load_single_data(
884875
else: # output is natoms_sel
885876
natoms = natoms_sel
886877
idx_map = idx_map_sel
887-
elif data.shape[1] == natoms * ndof:
878+
elif mmap_obj.shape[1] == natoms * ndof:
888879
data = data.reshape([natoms, ndof])
889880
if vv["output_natoms_for_type_sel"]:
890881
pass
@@ -894,7 +885,7 @@ def _load_single_data(
894885
natoms = natoms_sel
895886
else: # Shape mismatch error
896887
raise ValueError(
897-
f"The shape of the data {key} in {set_dir} has width {data.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})"
888+
f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})"
898889
)
899890

900891
# Handle special case for Hessian
@@ -911,14 +902,13 @@ def _load_single_data(
911902
# size of hessian is 3Natoms * 3Natoms
912903
# ndof = 3 * ndof * 3 * ndof
913904
else:
914-
# data should be 2D here: (natoms, ndof)
905+
# data should be 2D here: [natoms, ndof]
915906
data = data.reshape([natoms, -1])
916907
data = data[idx_map, :]
917-
# Atomic: return [natoms, ndof] or flattened hessian above
918-
return np.float32(1.0), data
919908

909+
# Atomic: return [natoms, ndof] or flattened hessian above
920910
# Non-atomic: return [ndof]
921-
return np.float32(1.0), data.reshape([ndof])
911+
return np.float32(1.0), data
922912

923913
except ValueError as err_message:
924914
explanation = (

0 commit comments

Comments
 (0)