1515from __future__ import annotations
1616
1717import glob
18- from datetime import datetime , timedelta
18+ import os .path as osp
19+ from datetime import datetime
20+ from datetime import timedelta
1921from typing import Dict
22+ from typing import List
2023from typing import Optional
2124from typing import Tuple
2225
@@ -31,16 +34,15 @@ class MRMSDataset(io.Dataset):
3134 """Class for MRMS dataset. MRMS day's data is stored in a .h5 file. Each file includes keys "date"/"time_interval"/"dataset".
3235
3336 Args:
34- file_path (str): Data set path.
37+ file_path (str): Dataset path.
3538 input_keys (Tuple[str, ...]): Input keys, usually there is only one, such as ("input",).
3639 label_keys (Tuple[str, ...]): Output keys, usually there is only one, such as ("output",).
3740 weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
3841 date_period (Tuple[str,...], optional): Dates of data. Scale is [start_date, end_date] with format "%Y%m%d". Defaults to ("20230101","20230101").
3942 num_input_timestamps (int, optional): Number of input timestamps. Defaults to 1.
4043 num_label_timestamps (int, optional): Number of label timestamps. Defaults to 1.
4144 stride (int, optional): Stride of sampling data. Defaults to 1.
42- transforms (Optional[vision.Compose]): Compose object contains sample wise
43- transform(s). Defaults to None.
45+ transforms (Optional[vision.Compose]): Composed transform functor(s). Defaults to None.
4446
4547 Examples:
4648 >>> import ppsci
@@ -65,12 +67,12 @@ def __init__(
6567 input_keys : Tuple [str , ...],
6668 label_keys : Tuple [str , ...],
6769 weight_dict : Optional [Dict [str , float ]] = None ,
68- date_period : Tuple [str ,...] = ("20230101" ,"20230101" ),
70+ date_period : Tuple [str , ...] = ("20230101" , "20230101" ),
6971 num_input_timestamps : int = 1 ,
7072 num_label_timestamps : int = 1 ,
7173 stride : int = 1 ,
7274 transforms : Optional [vision .Compose ] = None ,
73- ):
75+ ):
7476 super ().__init__ ()
7577 self .file_path = file_path
7678 self .input_keys = input_keys
@@ -81,17 +83,22 @@ def __init__(
8183 self .weight_dict = {key : 1.0 for key in self .label_keys }
8284 self .weight_dict .update (weight_dict )
8385
84- self .date_list = self .get_date_strs (date_period )
86+ self .date_list = self ._get_date_strs (date_period )
8587 self .num_input_timestamps = num_input_timestamps
8688 self .num_label_timestamps = num_label_timestamps
8789 self .stride = stride
8890 self .transforms = transforms
8991
90- self .files = self .read_data (file_path )
92+ self .files = self ._read_data (file_path )
9193 self .num_samples_per_day = self .files [0 ].shape [0 ]
9294 self .num_samples = self .num_samples_per_day * len (self .date_list )
93-
94- def get_date_strs (self , date_period ):
95+
96+ def _get_date_strs (self , date_period : Tuple [str , ...]) -> List :
97+ """Get a string list of all dates within given period.
98+
99+ Args:
100+ date_period (Tuple[str,...]): Dates of data. Scale is [start_date, end_date] with format "%Y%m%d".
101+ """
95102 start_time = datetime .strptime (date_period [0 ], "%Y%m%d" )
96103 end_time = datetime .strptime (date_period [1 ], "%Y%m%d" )
97104 results = []
@@ -102,31 +109,48 @@ def get_date_strs(self, date_period):
102109 current_time += timedelta (days = 1 )
103110 return results
104111
105- def read_data (self , path : str , var = "dataset" ):
106- paths = [path ] if path .endswith (".h5" ) else [_path for _path in glob .glob (path + "/*.h5" ) if _path .split (".h5" )[0 ].split ("_" )[- 1 ] in self .date_list ]
107- assert len (paths ) == len (self .date_list ), f"Data of { len (self .date_list )} mouths wanted but only { len (paths )} mouths be found"
112+ def _read_data (self , path : str ):
113+ if path .endswith (".h5" ):
114+ paths = [path ]
115+ else :
116+ paths = [
117+ _path
118+ for _path in glob .glob (osp .join (path , "*.h5" ))
119+ if _path .split (".h5" )[0 ].split ("_" )[- 1 ] in self .date_list
120+ ]
121+ assert len (paths ) == len (
122+ self .date_list
123+ ), f"Data of { len (self .date_list )} days wanted but only { len (paths )} days be found"
108124 paths .sort ()
109-
110- files = []
111- for _path in paths :
112- _file = h5py .File (_path , "r" )
113- files .append (_file [var ])
125+
126+ files = [h5py .File (_path , "r" )["dataset" ] for _path in paths ]
114127 return files
115128
116129 def __len__ (self ):
117- return self .num_samples // self .stride - self .num_input_timestamps - self .num_label_timestamps + 1
130+ return (
131+ self .num_samples // self .stride
132+ - self .num_input_timestamps
133+ - self .num_label_timestamps
134+ + 1
135+ )
118136
119137 def __getitem__ (self , global_idx ):
120138 global_idx *= self .stride
121- _samples = np .empty ((self .num_input_timestamps + self .num_label_timestamps , * self .files [0 ].shape [1 :]), dtype = paddle .get_default_dtype ())
122- for idx in range (self .num_input_timestamps + self .num_label_timestamps ):
123- sample_idx = global_idx + idx * self .stride
139+ _samples = np .empty (
140+ (
141+ self .num_input_timestamps + self .num_label_timestamps ,
142+ * self .files [0 ].shape [1 :],
143+ ),
144+ dtype = paddle .get_default_dtype (),
145+ )
146+ for idx in range (self .num_input_timestamps + self .num_label_timestamps ):
147+ sample_idx = global_idx + idx * self .stride
124148 day_idx = sample_idx // self .num_samples_per_day
125149 local_idx = sample_idx % self .num_samples_per_day
126- _samples [idx ]= self .files [day_idx ][local_idx ]
150+ _samples [idx ] = self .files [day_idx ][local_idx ]
127151
128- input_item = {self .input_keys [0 ]: _samples [:self .num_input_timestamps ]}
129- label_item = {self .label_keys [0 ]: _samples [self .num_input_timestamps :]}
152+ input_item = {self .input_keys [0 ]: _samples [: self .num_input_timestamps ]}
153+ label_item = {self .label_keys [0 ]: _samples [self .num_input_timestamps :]}
130154
131155 weight_shape = [1 ] * len (next (iter (label_item .values ())).shape )
132156 weight_item = {
@@ -143,17 +167,16 @@ def __getitem__(self, global_idx):
143167
144168
145169class MRMSSampledDataset (io .Dataset ):
146- """Class for MRMS sampled dataset.MRMS one sample's data is stored in a .h5 file. Each file includes keys "date"/"time_interval"/"dataset".
170+ """Class for MRMS sampled dataset. MRMS one sample's data is stored in a .h5 file. Each file includes keys "date"/"time_interval"/"dataset".
147171 The class just return data by input_item and values of label_item are empty for all label_keys.
148172
149173 Args:
150- file_path (str): Data set path.
174+ file_path (str): Dataset path.
151175 input_keys (Tuple[str, ...]): Input keys, such as ("input",).
152176 label_keys (Tuple[str, ...]): Output keys, such as ("output",).
153177 weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
154178 num_total_timestamps (int, optional): Number of timestamps of input+label. Defaults to 1.
155- transforms (Optional[vision.Compose]): Compose object contains sample wise
156- transform(s). Defaults to None.
179+ transforms (Optional[vision.Compose]): Composed transform functor(s). Defaults to None.
157180
158181 Examples:
159182 >>> import ppsci
@@ -192,16 +215,13 @@ def __init__(
192215 self .num_total_timestamps = num_total_timestamps
193216 self .transforms = transforms
194217
195- self .files = self .read_data (file_path )
218+ self .files = self ._read_data (file_path )
196219 self .num_samples = len (self .files )
197220
198- def read_data (self , path : str ):
199- paths = glob .glob (path + "/ *.h5" )
221+ def _read_data (self , path : str ):
222+ paths = glob .glob (osp . join ( path , " *.h5") )
200223 paths .sort ()
201- files = []
202- for _path in paths :
203- _file = h5py .File (_path , "r" )
204- files .append (_file )
224+ files = [h5py .File (_path , "r" )["dataset" ] for _path in paths ]
205225 return files
206226
207227 def __len__ (self ):
@@ -210,20 +230,15 @@ def __len__(self):
210230 def __getitem__ (self , global_idx ):
211231 _samples = []
212232 for idx in range (global_idx , global_idx + self .num_total_timestamps ):
213- _samples .append (np .expand_dims (self .files [idx ][ "dataset" ], axis = 0 ))
233+ _samples .append (np .expand_dims (self .files [idx ], axis = 0 ))
214234
215- input_item = {self .input_keys [0 ]: np .concatenate (_samples , axis = 0 ).astype (paddle .get_default_dtype ())}
235+ input_item = {
236+ self .input_keys [0 ]: np .concatenate (_samples , axis = 0 ).astype (
237+ paddle .get_default_dtype ()
238+ )
239+ }
216240 label_item = {}
217- for key in self .label_keys :
218- label_item [key ] = np .asarray ([], paddle .get_default_dtype ())
219-
220241 weight_item = {}
221- if len (label_item ) > 0 :
222- weight_shape = [1 ] * len (next (iter (label_item .values ())).shape )
223- weight_item = {
224- key : np .full (weight_shape , value , paddle .get_default_dtype ())
225- for key , value in self .weight_dict .items ()
226- }
227242
228243 if self .transforms is not None :
229244 input_item , label_item , weight_item = self .transforms (
0 commit comments