Skip to content

Commit 256a92b

Browse files
committed
perf: use multithreading to accelerate stat computing and loading
1 parent 4b9498b commit 256a92b

File tree

2 files changed

+110
-49
lines changed

2 files changed

+110
-49
lines changed

deepmd/pt/utils/stat.py

Lines changed: 78 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from collections import (
44
defaultdict,
55
)
6+
from concurrent.futures import (
7+
ThreadPoolExecutor,
8+
)
69
from typing import (
710
Any,
811
Callable,
@@ -39,7 +42,7 @@
3942
def make_stat_input(
    datasets: list[Any], dataloaders: list[Any], nbatches: int
) -> list[dict[str, Any]]:
    """Pack data for statistics in parallel.

    Each dataset's dataloader is drained by a worker thread; the work is
    I/O-bound (batch loading), so threads overlap the waits despite the GIL.

    Parameters
    ----------
    datasets : list[Any]
        A list of datasets to analyze (used only for its length here).
    dataloaders : list[Any]
        One dataloader per dataset.
    nbatches : int
        Maximum number of batches to draw from each dataloader.

    Returns
    -------
    list[dict[str, Any]]
        One dict of packed statistics per system, in dataset order.
    """
    log.info(f"Packing data for statistics from {len(datasets)} systems")
    # Snapshot len() up front so worker threads do not each recompute it.
    dataloader_lens = [len(dl) for dl in dataloaders]
    args_list = [
        (dataloaders[i], nbatches, dataloader_lens[i]) for i in range(len(datasets))
    ]

    # I/O intensive, so allow many workers; cap by the actual task count and
    # keep at least 1 so ThreadPoolExecutor never receives max_workers=0.
    n_workers = min(256, max(1, len(args_list)))
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # executor.map preserves input order, so lst[i] matches datasets[i].
        lst = list(executor.map(_process_one_dataset, args_list))
    log.info("Finished packing data.")
    return lst
67+
68+
69+
def _process_one_dataset(args: tuple[Any, int, int]) -> dict[str, Any]:
    """Process a single dataset's dataloader for statistics.

    Designed to be called in parallel by a ThreadPoolExecutor; it touches
    only its own ``sys_stat`` dict, so no locking is needed.

    Parameters
    ----------
    args : tuple[Any, int, int]
        A tuple containing (dataloader, nbatches, dataloader_len).

    Returns
    -------
    dict[str, Any]
        The processed sys_stat dictionary for one dataset: per-batch tensors
        concatenated along dim 0, np.float32 scalars kept as-is, and missing
        entries stored as None.
    """
    dataloader, nbatches, dataloader_len = args
    sys_stat: dict[str, Any] = {}

    # Accumulate on CPU so statistics batches do not occupy device memory.
    with torch.device("cpu"):
        iterator = iter(dataloader)
        numb_batches = min(nbatches, dataloader_len)

        for _ in range(numb_batches):
            try:
                stat_data = next(iterator)
            except StopIteration:
                # Dataloader exhausted early: restart it to reach numb_batches.
                iterator = iter(dataloader)
                stat_data = next(iterator)

            if (
                "find_fparam" in stat_data
                and "fparam" in stat_data
                and stat_data["find_fparam"] == 0.0
            ):
                # for model using default fparam
                stat_data.pop("fparam")
                stat_data.pop("find_fparam")

            for dd in stat_data:
                value = stat_data[dd]
                if value is None:
                    sys_stat[dd] = None
                elif isinstance(value, torch.Tensor):
                    # Collect per-batch tensors; concatenated after the loop.
                    sys_stat.setdefault(dd, []).append(value)
                elif isinstance(value, np.float32):
                    sys_stat[dd] = value
                # Other value types are deliberately ignored.

    # Merge the per-batch tensor lists; np.float32 scalars and None entries
    # pass through unchanged.
    for key, value in sys_stat.items():
        if isinstance(value, list):
            sys_stat[key] = None if value[0] is None else torch.cat(value, dim=0)

    # Move the finished statistics to the target device (project helper);
    # NOTE(review): assumes dict_to_device mutates sys_stat in place, as the
    # pre-commit code relied on — confirm against deepmd.pt.utils.
    dict_to_device(sys_stat)
    return sys_stat
95132

96133

97134
def _restore_from_file(

deepmd/utils/env_mat_stat.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from collections.abc import (
1111
Iterator,
1212
)
13+
from concurrent.futures import (
14+
ThreadPoolExecutor,
15+
)
1316
from typing import (
1417
Optional,
1518
)
@@ -142,7 +145,7 @@ def save_stats(self, path: DPPath) -> None:
142145
(path / kk).save_numpy(np.array([vv.number, vv.sum, vv.squared_sum]))
143146

144147
def load_stats(self, path: DPPath) -> None:
    """Load the statistics of the environment matrix in parallel.

    Each stat file is read by a worker thread; loading is I/O-bound, so
    threads overlap the reads.

    Parameters
    ----------
    path : DPPath
        The directory containing one saved array per statistic.

    Raises
    ------
    ValueError
        If the statistics were already computed, or if ``path`` contains no
        files. NOTE(review): raising on an empty directory is stricter than
        the previous sequential loader, which silently loaded nothing —
        confirm callers expect this.
    """
    if len(self.stats) > 0:
        raise ValueError("The statistics has already been computed.")

    files_to_load = list(path.glob("*"))

    if not files_to_load:
        raise ValueError(f"No statistics files found in {path}.")

    # Cap the pool at the file count; 128 threads is ample for I/O overlap.
    n_workers = min(128, len(files_to_load))
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # Materialize inside the with-block: relying on Executor.map's eager
        # submission to iterate results after shutdown is CPython-fragile.
        results = list(executor.map(self._load_stat_file, files_to_load))

    for name, stat_item in results:
        # Malformed/unreadable files come back as (name, None) and are skipped.
        if stat_item is not None:
            self.stats[name] = stat_item
161169

162170
def load_or_compute_stats(
163171
self, data: list[dict[str, np.ndarray]], path: Optional[DPPath] = None
@@ -216,3 +224,19 @@ def get_std(
216224
kk: vv.compute_std(default=default, protection=protection)
217225
for kk, vv in self.stats.items()
218226
}
227+
228+
@staticmethod
def _load_stat_file(file_path: DPPath) -> tuple[str, Optional[StatItem]]:
    """Load one stat file; return (name, StatItem) or (name, None) on failure.

    Helper for the parallel loader: it never raises, so one corrupt file
    cannot abort the whole ThreadPoolExecutor.map batch.

    Parameters
    ----------
    file_path : DPPath
        Path of a saved statistics array, expected shape (3,):
        [number, sum, squared_sum].

    Returns
    -------
    tuple[str, Optional[StatItem]]
        The file name and its StatItem, or None when the file is malformed
        or unreadable (a warning is logged instead of raising).
    """
    try:
        arr = file_path.load_numpy()
        if arr.shape == (3,):
            return file_path.name, StatItem(
                number=arr[0], sum=arr[1], squared_sum=arr[2]
            )
        else:
            log.warning(f"Skipping malformed stat file: {file_path.name}")
            return file_path.name, None
    except Exception as e:
        # Deliberate best-effort: log and skip rather than fail the batch.
        log.warning(f"Failed to load stat file {file_path.name}: {e}")
        return file_path.name, None

0 commit comments

Comments
 (0)