Skip to content

Commit f9298c4

Browse files
committed
fix code and style
1 parent 3affeae commit f9298c4

File tree

1 file changed

+62
-47
lines changed

1 file changed

+62
-47
lines changed

ppsci/data/dataset/mrms_dataset.py

Lines changed: 62 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
from __future__ import annotations
1616

1717
import glob
18-
from datetime import datetime, timedelta
18+
import os.path as osp
19+
from datetime import datetime
20+
from datetime import timedelta
1921
from typing import Dict
22+
from typing import List
2023
from typing import Optional
2124
from typing import Tuple
2225

@@ -31,16 +34,15 @@ class MRMSDataset(io.Dataset):
3134
"""Class for MRMS dataset. MRMS day's data is stored in a .h5 file. Each file includes keys "date"/"time_interval"/"dataset".
3235
3336
Args:
34-
file_path (str): Data set path.
37+
file_path (str): Dataset path.
3538
input_keys (Tuple[str, ...]): Input keys, usually there is only one, such as ("input",).
3639
label_keys (Tuple[str, ...]): Output keys, usually there is only one, such as ("output",).
3740
weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
3841
date_period (Tuple[str,...], optional): Dates of data. Scale is [start_date, end_date] with format "%Y%m%d". Defaults to ("20230101","20230101").
3942
num_input_timestamps (int, optional): Number of timestamps of input. Defaults to 1.
4043
num_label_timestamps (int, optional): Number of timestamps of label. Defaults to 1.
4144
stride (int, optional): Stride of sampling data. Defaults to 1.
42-
transforms (Optional[vision.Compose]): Compose object contains sample wise
43-
transform(s). Defaults to None.
45+
transforms (Optional[vision.Compose]): Composed transform functor(s). Defaults to None.
4446
4547
Examples:
4648
>>> import ppsci
@@ -65,12 +67,12 @@ def __init__(
6567
input_keys: Tuple[str, ...],
6668
label_keys: Tuple[str, ...],
6769
weight_dict: Optional[Dict[str, float]] = None,
68-
date_period: Tuple[str,...] = ("20230101","20230101"),
70+
date_period: Tuple[str, ...] = ("20230101", "20230101"),
6971
num_input_timestamps: int = 1,
7072
num_label_timestamps: int = 1,
7173
stride: int = 1,
7274
transforms: Optional[vision.Compose] = None,
73-
):
75+
):
7476
super().__init__()
7577
self.file_path = file_path
7678
self.input_keys = input_keys
@@ -81,17 +83,22 @@ def __init__(
8183
self.weight_dict = {key: 1.0 for key in self.label_keys}
8284
self.weight_dict.update(weight_dict)
8385

84-
self.date_list = self.get_date_strs(date_period)
86+
self.date_list = self._get_date_strs(date_period)
8587
self.num_input_timestamps = num_input_timestamps
8688
self.num_label_timestamps = num_label_timestamps
8789
self.stride = stride
8890
self.transforms = transforms
8991

90-
self.files = self.read_data(file_path)
92+
self.files = self._read_data(file_path)
9193
self.num_samples_per_day = self.files[0].shape[0]
9294
self.num_samples = self.num_samples_per_day * len(self.date_list)
93-
94-
def get_date_strs(self, date_period):
95+
96+
def _get_date_strs(self, date_period: Tuple[str, ...]) -> List:
97+
"""Get a string list of all dates within given period.
98+
99+
Args:
100+
date_period (Tuple[str,...]): Dates of data. Scale is [start_date, end_date] with format "%Y%m%d".
101+
"""
95102
start_time = datetime.strptime(date_period[0], "%Y%m%d")
96103
end_time = datetime.strptime(date_period[1], "%Y%m%d")
97104
results = []
@@ -102,31 +109,48 @@ def get_date_strs(self, date_period):
102109
current_time += timedelta(days=1)
103110
return results
104111

105-
def read_data(self, path: str, var="dataset"):
106-
paths = [path] if path.endswith(".h5") else [_path for _path in glob.glob(path + "/*.h5") if _path.split(".h5")[0].split("_")[-1] in self.date_list]
107-
assert len(paths) == len(self.date_list), f"Data of {len(self.date_list)} mouths wanted but only {len(paths)} mouths be found"
112+
def _read_data(self, path: str):
113+
if path.endswith(".h5"):
114+
paths = [path]
115+
else:
116+
paths = [
117+
_path
118+
for _path in glob.glob(osp.join(path, "*.h5"))
119+
if _path.split(".h5")[0].split("_")[-1] in self.date_list
120+
]
121+
assert len(paths) == len(
122+
self.date_list
123+
), f"Data of {len(self.date_list)} days wanted but only {len(paths)} days were found"
108124
paths.sort()
109-
110-
files = []
111-
for _path in paths:
112-
_file = h5py.File(_path, "r")
113-
files.append(_file[var])
125+
126+
files = [h5py.File(_path, "r")["dataset"] for _path in paths]
114127
return files
115128

116129
def __len__(self):
117-
return self.num_samples//self.stride - self.num_input_timestamps - self.num_label_timestamps + 1
130+
return (
131+
self.num_samples // self.stride
132+
- self.num_input_timestamps
133+
- self.num_label_timestamps
134+
+ 1
135+
)
118136

119137
def __getitem__(self, global_idx):
120138
global_idx *= self.stride
121-
_samples = np.empty((self.num_input_timestamps + self.num_label_timestamps, *self.files[0].shape[1:]), dtype=paddle.get_default_dtype())
122-
for idx in range(self.num_input_timestamps+self.num_label_timestamps):
123-
sample_idx = global_idx + idx*self.stride
139+
_samples = np.empty(
140+
(
141+
self.num_input_timestamps + self.num_label_timestamps,
142+
*self.files[0].shape[1:],
143+
),
144+
dtype=paddle.get_default_dtype(),
145+
)
146+
for idx in range(self.num_input_timestamps + self.num_label_timestamps):
147+
sample_idx = global_idx + idx * self.stride
124148
day_idx = sample_idx // self.num_samples_per_day
125149
local_idx = sample_idx % self.num_samples_per_day
126-
_samples[idx]=self.files[day_idx][local_idx]
150+
_samples[idx] = self.files[day_idx][local_idx]
127151

128-
input_item = {self.input_keys[0]: _samples[:self.num_input_timestamps]}
129-
label_item = {self.label_keys[0]: _samples[self.num_input_timestamps:]}
152+
input_item = {self.input_keys[0]: _samples[: self.num_input_timestamps]}
153+
label_item = {self.label_keys[0]: _samples[self.num_input_timestamps :]}
130154

131155
weight_shape = [1] * len(next(iter(label_item.values())).shape)
132156
weight_item = {
@@ -143,17 +167,16 @@ def __getitem__(self, global_idx):
143167

144168

145169
class MRMSSampledDataset(io.Dataset):
146-
"""Class for MRMS sampled dataset.MRMS one sample's data is stored in a .h5 file. Each file includes keys "date"/"time_interval"/"dataset".
170+
"""Class for MRMS sampled dataset. MRMS one sample's data is stored in a .h5 file. Each file includes keys "date"/"time_interval"/"dataset".
147171
The class just returns data by input_item and values of label_item are empty for all label_keys.
148172
149173
Args:
150-
file_path (str): Data set path.
174+
file_path (str): Dataset path.
151175
input_keys (Tuple[str, ...]): Input keys, such as ("input",).
152176
label_keys (Tuple[str, ...]): Output keys, such as ("output",).
153177
weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
154178
num_total_timestamps (int, optional): Number of timestamps of input+label. Defaults to 1.
155-
transforms (Optional[vision.Compose]): Compose object contains sample wise
156-
transform(s). Defaults to None.
179+
transforms (Optional[vision.Compose]): Composed transform functor(s). Defaults to None.
157180
158181
Examples:
159182
>>> import ppsci
@@ -192,16 +215,13 @@ def __init__(
192215
self.num_total_timestamps = num_total_timestamps
193216
self.transforms = transforms
194217

195-
self.files = self.read_data(file_path)
218+
self.files = self._read_data(file_path)
196219
self.num_samples = len(self.files)
197220

198-
def read_data(self, path: str):
199-
paths = glob.glob(path + "/*.h5")
221+
def _read_data(self, path: str):
222+
paths = glob.glob(osp.join(path, "*.h5"))
200223
paths.sort()
201-
files = []
202-
for _path in paths:
203-
_file = h5py.File(_path, "r")
204-
files.append(_file)
224+
files = [h5py.File(_path, "r")["dataset"] for _path in paths]
205225
return files
206226

207227
def __len__(self):
@@ -210,20 +230,15 @@ def __len__(self):
210230
def __getitem__(self, global_idx):
211231
_samples = []
212232
for idx in range(global_idx, global_idx + self.num_total_timestamps):
213-
_samples.append(np.expand_dims(self.files[idx]["dataset"],axis=0))
233+
_samples.append(np.expand_dims(self.files[idx], axis=0))
214234

215-
input_item = {self.input_keys[0]: np.concatenate(_samples, axis=0).astype(paddle.get_default_dtype())}
235+
input_item = {
236+
self.input_keys[0]: np.concatenate(_samples, axis=0).astype(
237+
paddle.get_default_dtype()
238+
)
239+
}
216240
label_item = {}
217-
for key in self.label_keys:
218-
label_item[key] = np.asarray([], paddle.get_default_dtype())
219-
220241
weight_item = {}
221-
if len(label_item) > 0:
222-
weight_shape = [1] * len(next(iter(label_item.values())).shape)
223-
weight_item = {
224-
key: np.full(weight_shape, value, paddle.get_default_dtype())
225-
for key, value in self.weight_dict.items()
226-
}
227242

228243
if self.transforms is not None:
229244
input_item, label_item, weight_item = self.transforms(

0 commit comments

Comments
 (0)