Skip to content

Commit e6d4638

Browse files
committed
perf: avoid repeated _get_nframes
1 parent d877e5b commit e6d4638

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

deepmd/utils/data.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,12 @@ def get_single_frame(self, index: int) -> dict:
385385
set_dir = DPPath(set_dir)
386386
# Calculate local index within the set.* directory
387387
local_idx = index - (0 if set_idx == 0 else self.prefix_sum[set_idx - 1])
388+
# Calculate the number of frames in this set to avoid redundant _get_nframes calls
389+
set_nframes = (
390+
self.prefix_sum[set_idx]
391+
if set_idx == 0
392+
else self.prefix_sum[set_idx] - self.prefix_sum[set_idx - 1]
393+
)
388394

389395
frame_data = {}
390396
# 2. Concurrently load all non-reduced items
@@ -395,7 +401,7 @@ def get_single_frame(self, index: int) -> dict:
395401
with ThreadPoolExecutor(max_workers=len(non_reduced_keys)) as executor:
396402
future_to_key = {
397403
executor.submit(
398-
self._load_single_data, set_dir, key, local_idx
404+
self._load_single_data, set_dir, key, local_idx, set_nframes
399405
): key
400406
for key in non_reduced_keys
401407
}
@@ -805,11 +811,22 @@ def _load_data(
805811
return np.float32(0.0), data
806812

807813
def _load_single_data(
808-
self, set_dir: DPPath, key: str, frame_idx: int
814+
self, set_dir: DPPath, key: str, frame_idx: int, set_nframes: int
809815
) -> tuple[np.float32, np.ndarray]:
810816
"""
811817
Loads and processes data for a SINGLE frame from a SINGLE key,
812818
fully replicating the logic from the original _load_data method.
819+
820+
Parameters
821+
----------
822+
set_dir : DPPath
823+
The directory path of the set
824+
key : str
825+
The key name of the data to load
826+
frame_idx : int
827+
The local frame index within the set
828+
set_nframes : int
829+
The total number of frames in this set (to avoid redundant _get_nframes calls)
813830
"""
814831
vv = self.data_dict[key]
815832
path = set_dir / (key + ".npy")
@@ -859,7 +876,7 @@ def _load_single_data(
859876
# Branch 2: File exists, use memmap
860877
mmap_obj = self._get_memmap(path)
861878
# corner case: single frame
862-
if self._get_nframes(set_dir) == 1:
879+
if set_nframes == 1:
863880
mmap_obj = mmap_obj[None, ...]
864881
# Slice the single frame and make an in-memory copy for modification
865882
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False)

0 commit comments

Comments
 (0)