Skip to content

Commit 97ddb72

Browse files
committed
bug fix & simplify
1 parent b460dc1 commit 97ddb72

File tree

3 files changed

+14
-15
lines changed

3 files changed

+14
-15
lines changed

deepmd/env.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,23 @@
5050
)
5151

5252
# Dynamic calculation of cache size
53-
DEFAULT_LRU_CACHE_SIZE = 888
54-
LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE
53+
_default_lru_cache_size = 888
54+
LRU_CACHE_SIZE = _default_lru_cache_size
5555

5656
if platform.system() != "Windows":
5757
try:
5858
import resource
5959

6060
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
6161
safe_buffer = 128
62-
if soft_limit > safe_buffer + DEFAULT_LRU_CACHE_SIZE:
62+
if soft_limit > safe_buffer + _default_lru_cache_size:
6363
LRU_CACHE_SIZE = soft_limit - safe_buffer
6464
else:
6565
LRU_CACHE_SIZE = soft_limit // 2
6666
except ImportError:
67-
LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE
67+
LRU_CACHE_SIZE = _default_lru_cache_size
6868
else:
69-
LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE
69+
LRU_CACHE_SIZE = _default_lru_cache_size
7070

7171

7272
def set_env_if_empty(key: str, value: str, verbose: bool = True) -> None:

deepmd/pd/utils/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __len__(self):
3636

3737
def __getitem__(self, index):
3838
"""Get a frame from the selected system."""
39-
b_data = self._data_system.get_frame_paddle(index)
39+
b_data = self._data_system.get_item_paddle(index)
4040
b_data["natoms"] = self._natoms_vec
4141
return b_data
4242

deepmd/utils/data.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ def __init__(
134134
self.shuffle_test = shuffle_test
135135
# set modifier
136136
self.modifier = modifier
137-
# Add a cache for memory-mapped files
138-
self.memmap_cache = {}
139137
# calculate prefix sum for get_item method
140138
frames_list = [self._get_nframes(item) for item in self.dirs]
141139
self.nframes = np.sum(frames_list)
@@ -451,11 +449,12 @@ def get_single_frame(self, index: int) -> dict:
451449

452450
# 6. Reshape atomic data to match expected format [natoms, ndof]
453451
for kk in self.data_dict.keys():
454-
if "find_" in kk:
455-
pass
456-
else:
457-
if kk in frame_data and not self.data_dict[kk]["atomic"]:
458-
frame_data[kk] = frame_data[kk].reshape(-1)
452+
if (
453+
"find_" not in kk
454+
and kk in frame_data
455+
and not self.data_dict[kk]["atomic"]
456+
):
457+
frame_data[kk] = frame_data[kk].reshape(-1)
459458
frame_data["atype"] = frame_data["type"]
460459

461460
if not self.pbc:
@@ -782,7 +781,7 @@ def _load_data(
782781
data = data.reshape([nframes, -1])
783782
data = np.reshape(data, [nframes, ndof])
784783
except ValueError as err_message:
785-
explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."
784+
explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."
786785
log.error(str(err_message))
787786
log.error(explanation)
788787
raise ValueError(str(err_message) + ". " + explanation) from err_message
@@ -904,7 +903,7 @@ def _load_single_data(
904903
return np.float32(1.0), data
905904

906905
except ValueError as err_message:
907-
explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."
906+
explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."
908907
log.error(str(err_message))
909908
log.error(explanation)
910909
raise ValueError(str(err_message) + ". " + explanation) from err_message

0 commit comments

Comments
 (0)