33from collections import (
44 defaultdict ,
55)
6+ from concurrent .futures import (
7+ ThreadPoolExecutor ,
8+ )
69from typing import (
710 Any ,
811 Callable ,
3942def make_stat_input (
4043 datasets : list [Any ], dataloaders : list [Any ], nbatches : int
4144) -> dict [str , Any ]:
42- """Pack data for statistics.
45+ """Pack data for statistics in parallel .
4346
4447 Args:
4548 - dataset: A list of dataset to analyze.
@@ -49,49 +52,83 @@ def make_stat_input(
4952 -------
5053 - a list of dicts, each of which contains data from a system
5154 """
52- lst = []
5355 log .info (f"Packing data for statistics from { len (datasets )} systems" )
54- for i in range (len (datasets )):
55- sys_stat = {}
56- with torch .device ("cpu" ):
57- iterator = iter (dataloaders [i ])
58- numb_batches = min (nbatches , len (dataloaders [i ]))
59- for _ in range (numb_batches ):
60- try :
61- stat_data = next (iterator )
62- except StopIteration :
63- iterator = iter (dataloaders [i ])
64- stat_data = next (iterator )
65- if (
66- "find_fparam" in stat_data
67- and "fparam" in stat_data
68- and stat_data ["find_fparam" ] == 0.0
69- ):
70- # for model using default fparam
71- stat_data .pop ("fparam" )
72- stat_data .pop ("find_fparam" )
73- for dd in stat_data :
74- if stat_data [dd ] is None :
75- sys_stat [dd ] = None
76- elif isinstance (stat_data [dd ], torch .Tensor ):
77- if dd not in sys_stat :
78- sys_stat [dd ] = []
79- sys_stat [dd ].append (stat_data [dd ])
80- elif isinstance (stat_data [dd ], np .float32 ):
81- sys_stat [dd ] = stat_data [dd ]
82- else :
83- pass
84-
85- for key in sys_stat :
86- if isinstance (sys_stat [key ], np .float32 ):
87- pass
88- elif sys_stat [key ] is None or sys_stat [key ][0 ] is None :
56+ dataloader_lens = [len (dl ) for dl in dataloaders ]
57+ args_list = [
58+ (dataloaders [i ], nbatches , dataloader_lens [i ]) for i in range (len (datasets ))
59+ ]
60+
61+ lst = []
62+ # I/O intensive, set a larger number of workers
63+ with ThreadPoolExecutor (max_workers = 256 ) as executor :
64+ lst = list (executor .map (_process_one_dataset , args_list ))
65+ log .info ("Finished packing data." )
66+ return lst
67+
68+
69+ def _process_one_dataset (args : tuple [Any , int , int ]) -> dict [str , Any ]:
70+ """
71+ Helper function to process a single dataset's dataloader for statistics.
72+ Designed to be called in parallel by a ThreadPoolExecutor.
73+
74+ Parameters
75+ ----------
76+ args : tuple(Any, int, int)
77+ A tuple containing (dataloader, nbatches, dataloader_len)
78+
79+ Returns
80+ -------
81+ dict[str, Any]
82+ The processed sys_stat dictionary for one dataset.
83+ """
84+ dataloader , nbatches , dataloader_len = args
85+ sys_stat = {}
86+
87+ with torch .device ("cpu" ):
88+ iterator = iter (dataloader )
89+ numb_batches = min (nbatches , dataloader_len )
90+
91+ for _ in range (numb_batches ):
92+ try :
93+ stat_data = next (iterator )
94+ except StopIteration :
95+ iterator = iter (dataloader )
96+ stat_data = next (iterator )
97+
98+ if (
99+ "find_fparam" in stat_data
100+ and "fparam" in stat_data
101+ and stat_data ["find_fparam" ] == 0.0
102+ ):
103+ # for model using default fparam
104+ stat_data .pop ("fparam" )
105+ stat_data .pop ("find_fparam" )
106+
107+ for dd in stat_data :
108+ if stat_data [dd ] is None :
109+ sys_stat [dd ] = None
110+ elif isinstance (stat_data [dd ], torch .Tensor ):
111+ if dd not in sys_stat :
112+ sys_stat [dd ] = []
113+ sys_stat [dd ].append (stat_data [dd ])
114+ elif isinstance (stat_data [dd ], np .float32 ):
115+ sys_stat [dd ] = stat_data [dd ]
116+ else :
117+ pass
118+
119+ for key in sys_stat :
120+ if isinstance (sys_stat [key ], np .float32 ):
121+ pass
122+ elif isinstance (sys_stat [key ], list ):
123+ if sys_stat [key ][0 ] is None :
89124 sys_stat [key ] = None
90- elif isinstance ( stat_data [ dd ], torch . Tensor ) :
125+ else :
91126 sys_stat [key ] = torch .cat (sys_stat [key ], dim = 0 )
92- dict_to_device (sys_stat )
93- lst .append (sys_stat )
94- return lst
127+ elif sys_stat [key ] is None :
128+ pass
129+
130+ dict_to_device (sys_stat )
131+ return sys_stat
95132
96133
97134def _restore_from_file (
0 commit comments