
Commit 3affeae

[Add]MRMSDataset
1 parent 3e1d0ad commit 3affeae

2 files changed: 237 additions & 0 deletions

ppsci/data/dataset/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,8 @@
 from ppsci.data.dataset.era5_dataset import ERA5SampledDataset
 from ppsci.data.dataset.mat_dataset import IterableMatDataset
 from ppsci.data.dataset.mat_dataset import MatDataset
+from ppsci.data.dataset.mrms_dataset import MRMSDataset
+from ppsci.data.dataset.mrms_dataset import MRMSSampledDataset
 from ppsci.data.dataset.npz_dataset import IterableNPZDataset
 from ppsci.data.dataset.npz_dataset import NPZDataset
 from ppsci.data.dataset.radar_dataset import RadarDataset
@@ -47,6 +49,8 @@
     "ERA5SampledDataset",
     "IterableMatDataset",
     "MatDataset",
+    "MRMSDataset",
+    "MRMSSampledDataset",
     "IterableNPZDataset",
     "NPZDataset",
     "CylinderDataset",

ppsci/data/dataset/mrms_dataset.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import glob
from datetime import datetime, timedelta
from typing import Dict
from typing import Optional
from typing import Tuple

import h5py
import numpy as np
import paddle
from paddle import io
from paddle import vision


class MRMSDataset(io.Dataset):
    """Class for the MRMS dataset. Each day's MRMS data is stored in a single .h5
    file, which contains the keys "date"/"time_interval"/"dataset".

    Args:
        file_path (str): Dataset path.
        input_keys (Tuple[str, ...]): Input keys, usually there is only one, such as ("input",).
        label_keys (Tuple[str, ...]): Output keys, usually there is only one, such as ("output",).
        weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
        date_period (Tuple[str, ...], optional): Date range of the data, given as
            (start_date, end_date) in "%Y%m%d" format. Defaults to ("20230101", "20230101").
        num_input_timestamps (int, optional): Number of input timestamps. Defaults to 1.
        num_label_timestamps (int, optional): Number of label timestamps. Defaults to 1.
        stride (int, optional): Stride for sampling data. Defaults to 1.
        transforms (Optional[vision.Compose]): Compose object containing sample-wise
            transform(s). Defaults to None.

    Examples:
        >>> import ppsci
        >>> dataset = ppsci.data.dataset.MRMSDataset(
        ...     file_path="/path/to/MRMSDataset",
        ...     input_keys=("input",),
        ...     label_keys=("output",),
        ...     date_period=("20230101", "20230131"),
        ...     num_input_timestamps=9,
        ...     num_label_timestamps=20,
        ...     transforms=transform,
        ...     stride=1,
        ... )  # doctest: +SKIP
    """

    # Whether batch indexing is supported, to speed up the fetching process.
    batch_index: bool = False

    def __init__(
        self,
        file_path: str,
        input_keys: Tuple[str, ...],
        label_keys: Tuple[str, ...],
        weight_dict: Optional[Dict[str, float]] = None,
        date_period: Tuple[str, ...] = ("20230101", "20230101"),
        num_input_timestamps: int = 1,
        num_label_timestamps: int = 1,
        stride: int = 1,
        transforms: Optional[vision.Compose] = None,
    ):
        super().__init__()
        self.file_path = file_path
        self.input_keys = input_keys
        self.label_keys = label_keys

        self.weight_dict = {} if weight_dict is None else weight_dict
        if weight_dict is not None:
            # Every label key defaults to weight 1.0; user-provided entries override it.
            self.weight_dict = {key: 1.0 for key in self.label_keys}
            self.weight_dict.update(weight_dict)

        self.date_list = self.get_date_strs(date_period)
        self.num_input_timestamps = num_input_timestamps
        self.num_label_timestamps = num_label_timestamps
        self.stride = stride
        self.transforms = transforms

        self.files = self.read_data(file_path)
        self.num_samples_per_day = self.files[0].shape[0]
        self.num_samples = self.num_samples_per_day * len(self.date_list)

    def get_date_strs(self, date_period):
        """Enumerate every date between start_date and end_date (inclusive) as a "%Y%m%d" string."""
        start_time = datetime.strptime(date_period[0], "%Y%m%d")
        end_time = datetime.strptime(date_period[1], "%Y%m%d")
        results = []
        current_time = start_time
        while current_time <= end_time:
            date_str = current_time.strftime("%Y%m%d")
            results.append(date_str)
            current_time += timedelta(days=1)
        return results

    def read_data(self, path: str, var="dataset"):
        if path.endswith(".h5"):
            paths = [path]
        else:
            paths = [
                _path
                for _path in glob.glob(path + "/*.h5")
                if _path.split(".h5")[0].split("_")[-1] in self.date_list
            ]
        assert len(paths) == len(
            self.date_list
        ), f"Data for {len(self.date_list)} days expected but only {len(paths)} days found"
        paths.sort()

        files = []
        for _path in paths:
            _file = h5py.File(_path, "r")
            files.append(_file[var])
        return files

    def __len__(self):
        return (
            self.num_samples // self.stride
            - self.num_input_timestamps
            - self.num_label_timestamps
            + 1
        )

    def __getitem__(self, global_idx):
        global_idx *= self.stride
        _samples = np.empty(
            (
                self.num_input_timestamps + self.num_label_timestamps,
                *self.files[0].shape[1:],
            ),
            dtype=paddle.get_default_dtype(),
        )

        for idx in range(self.num_input_timestamps + self.num_label_timestamps):
            sample_idx = global_idx + idx * self.stride
            day_idx = sample_idx // self.num_samples_per_day
            local_idx = sample_idx % self.num_samples_per_day
            _samples[idx] = self.files[day_idx][local_idx]

        input_item = {self.input_keys[0]: _samples[: self.num_input_timestamps]}
        label_item = {self.label_keys[0]: _samples[self.num_input_timestamps :]}

        weight_shape = [1] * len(next(iter(label_item.values())).shape)
        weight_item = {
            key: np.full(weight_shape, value, paddle.get_default_dtype())
            for key, value in self.weight_dict.items()
        }

        if self.transforms is not None:
            input_item, label_item, weight_item = self.transforms(
                input_item, label_item, weight_item
            )

        return input_item, label_item, weight_item


class MRMSSampledDataset(io.Dataset):
    """Class for the sampled MRMS dataset. Each sample's MRMS data is stored in its
    own .h5 file, which contains the keys "date"/"time_interval"/"dataset".
    This class only returns data through input_item; the values of label_item are
    empty arrays for all label_keys.

    Args:
        file_path (str): Dataset path.
        input_keys (Tuple[str, ...]): Input keys, such as ("input",).
        label_keys (Tuple[str, ...]): Output keys, such as ("output",).
        weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
        num_total_timestamps (int, optional): Total number of timestamps (input + label). Defaults to 1.
        transforms (Optional[vision.Compose]): Compose object containing sample-wise
            transform(s). Defaults to None.

    Examples:
        >>> import ppsci
        >>> dataset = ppsci.data.dataset.MRMSSampledDataset(
        ...     file_path="/path/to/MRMSSampledDataset",
        ...     input_keys=("input",),
        ...     label_keys=("output",),
        ...     num_total_timestamps=29,
        ... )  # doctest: +SKIP
        >>> # get the length of the dataset
        >>> dataset_size = len(dataset)  # doctest: +SKIP
        >>> # get the first sample of the data
        >>> first_sample = dataset[0]  # doctest: +SKIP
        >>> print("First sample:", first_sample)  # doctest: +SKIP
    """

    def __init__(
        self,
        file_path: str,
        input_keys: Tuple[str, ...],
        label_keys: Tuple[str, ...],
        weight_dict: Optional[Dict[str, float]] = None,
        num_total_timestamps: int = 1,
        transforms: Optional[vision.Compose] = None,
    ):
        super().__init__()
        self.file_path = file_path
        self.input_keys = input_keys
        self.label_keys = label_keys

        self.weight_dict = {} if weight_dict is None else weight_dict
        if weight_dict is not None:
            self.weight_dict = {key: 1.0 for key in self.label_keys}
            self.weight_dict.update(weight_dict)

        self.num_total_timestamps = num_total_timestamps
        self.transforms = transforms

        self.files = self.read_data(file_path)
        self.num_samples = len(self.files)

    def read_data(self, path: str):
        paths = glob.glob(path + "/*.h5")
        paths.sort()
        files = []
        for _path in paths:
            _file = h5py.File(_path, "r")
            files.append(_file)
        return files

    def __len__(self):
        return self.num_samples - self.num_total_timestamps + 1

    def __getitem__(self, global_idx):
        _samples = []
        for idx in range(global_idx, global_idx + self.num_total_timestamps):
            _samples.append(np.expand_dims(self.files[idx]["dataset"], axis=0))

        input_item = {
            self.input_keys[0]: np.concatenate(_samples, axis=0).astype(
                paddle.get_default_dtype()
            )
        }
        label_item = {}
        for key in self.label_keys:
            label_item[key] = np.asarray([], paddle.get_default_dtype())

        weight_item = {}
        if len(label_item) > 0:
            weight_shape = [1] * len(next(iter(label_item.values())).shape)
            weight_item = {
                key: np.full(weight_shape, value, paddle.get_default_dtype())
                for key, value in self.weight_dict.items()
            }

        if self.transforms is not None:
            input_item, label_item, weight_item = self.transforms(
                input_item, label_item, weight_item
            )

        return input_item, label_item, weight_item
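
The loaders assume each .h5 file exposes a "dataset" array (plus "date" and "time_interval" entries, which the loaders themselves never read). A hypothetical writer for a tiny synthetic daily file, useful for smoke tests; all shapes, names, and values here are illustrative, not part of this commit:

import h5py
import numpy as np

# Illustrative only: 144 frames of 256x256 mosaics for one day.
with h5py.File("mrms_20230101.h5", "w") as f:
    f.create_dataset("dataset", data=np.zeros((144, 256, 256), dtype="float32"))
    f.create_dataset("date", data=b"20230101")
    f.create_dataset("time_interval", data=b"10min")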
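For context, a minimal usage sketch for MRMSDataset. The directory layout and file names are hypothetical; read_data only requires daily .h5 files whose names end with an "_<%Y%m%d>.h5" date falling inside date_period:

import ppsci

# Hypothetical layout: /data/MRMS/mrms_20230101.h5 ... mrms_20230131.h5, each
# holding a "dataset" array of shape (num_samples_per_day, H, W).
dataset = ppsci.data.dataset.MRMSDataset(
    file_path="/data/MRMS",
    input_keys=("input",),
    label_keys=("output",),
    date_period=("20230101", "20230131"),
    num_input_timestamps=9,
    num_label_timestamps=20,
    stride=1,
)

# Sample 0 covers frames 0..28: the first 9 frames form the input, the next 20
# the label. Frame k is read from day file k // num_samples_per_day at local
# index k % num_samples_per_day.
input_item, label_item, weight_item = dataset[0]
print(input_item["input"].shape, label_item["output"].shape)  # (9, H, W) (20, H, W)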
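And a companion sketch for MRMSSampledDataset, again with a hypothetical path. Each .h5 file contributes one timestamp, and label_item holds empty arrays, so downstream code consumes input_item only:

import ppsci

dataset = ppsci.data.dataset.MRMSSampledDataset(
    file_path="/data/MRMS_sampled",
    input_keys=("input",),
    label_keys=("output",),
    num_total_timestamps=29,
)

# Stacks the "dataset" array from 29 consecutive files along a new axis 0.
input_item, label_item, weight_item = dataset[0]
print(input_item["input"].shape)  # (29, H, W)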
