
Commit 594e55e

add YingLong model (#771)
* add yinglong inference code
* add readme file for yinglong
* update yinglong predict code
* add copyright in predict_12layers.py
* update readme
* add reference for timefeatures.py
* update predict_12layers and readme
1 parent 205cc6f commit 594e55e

File tree

examples/yinglong/README.md
examples/yinglong/predict_12layers.py
examples/yinglong/timefeatures.py

3 files changed: +420 -0 lines changed

examples/yinglong/README.md

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Skillful High Resolution Regional Short Term Forecasting with Boundary Smoothing

YingLong is an AI-based, high-resolution, short-term regional weather forecasting model that predicts hourly weather fields, including wind speed, temperature, and specific humidity, at a 3 km resolution. YingLong uses a parallel structure of global and local blocks to capture multiscale meteorological features and is trained on an analysis dataset. In addition, the necessary information around the regional boundary is introduced to YingLong through a boundary smoothing strategy, which significantly improves the regional forecasting results. Compared with WRF-ARW, one of the best numerical weather prediction models, YingLong demonstrates superior forecasting performance in most cases, especially for surface variables.
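To make the parallel global/local idea concrete, here is a toy sketch. It is not the actual YingLong architecture (this commit ships only exported inference models): it only illustrates, with assumed layer sizes, how a local convolutional branch and a pooled global-context branch can process the same feature map in parallel and then be merged.

``` python
import paddle
import paddle.nn as nn


class ToyGlobalLocalBlock(nn.Layer):
    """Illustrative parallel global/local block (not the real YingLong block)."""

    def __init__(self, channels: int):
        super().__init__()
        # local branch: small-kernel convolution captures fine-scale structure
        self.local_branch = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        # global branch: domain-wide context pooled to one value per channel
        self.global_pool = nn.AdaptiveAvgPool2D(output_size=1)
        self.global_proj = nn.Conv2D(channels, channels, kernel_size=1)

    def forward(self, x):
        local_feat = self.local_branch(x)
        global_feat = self.global_proj(self.global_pool(x))  # shape [N, C, 1, 1]
        # broadcasting spreads the global context back over the whole grid
        return local_feat + global_feat


if __name__ == "__main__":
    # 24 variables on the 440 x 408 example grid
    x = paddle.randn([1, 24, 440, 408])
    print(ToyGlobalLocalBlock(channels=24)(x).shape)  # [1, 24, 440, 408]
```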
This code is the inference implementation of YingLong. We select the southeastern region of the United States, around the range of 110-130E, 15-35N, with 440 × 408 grid points in a Lambert projection.
## Installation

### 1. Install PaddlePaddle

Please install the <font color="red"><b>2.5.2</b></font> version of PaddlePaddle according to your environment, following the official [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html).

For example, if your environment is Linux with CUDA 11.2, you can install PaddlePaddle with the following command.

``` shell
python -m pip install paddlepaddle-gpu==2.5.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```

After installation, run the following command to verify that PaddlePaddle has been installed successfully.

``` shell
python -c "import paddle; paddle.utils.run_check()"
```

If `"PaddlePaddle is installed successfully! Let's start deep learning with PaddlePaddle now."` appears, the installation was successful.
### 2. Install PaddleScience

Clone the code of PaddleScience from [here](https://github.com/PaddlePaddle/PaddleScience.git).

``` shell
git clone -b develop https://github.com/PaddlePaddle/PaddleScience.git
```
## Example Usage

### 1. Download the data and model weights

``` shell
cd PaddleScience
wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/yinglong/hrrr_example_24vars.tar
tar -xvf hrrr_example_24vars.tar
wget https://paddle-org.bj.bcebos.com/paddlescience/models/yinglong/yinglong_models.tar
tar -xvf yinglong_models.tar
```
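To confirm that the download and extraction succeeded, the example files can be inspected with a short script such as the sketch below; it only assumes the `fields` dataset key and the default file paths that `examples/yinglong/predict_12layers.py` reads.

``` python
import h5py
import numpy as np

# paths below are the defaults used by examples/yinglong/predict_12layers.py
with h5py.File("./hrrr_example_24vars/valid/2022/01/01.h5", "r") as f:
    print("input fields shape:", f["fields"].shape)

mean = np.load("./hrrr_example_24vars/stat/mean_crop.npy")
std = np.load("./hrrr_example_24vars/stat/std_crop.npy")
print("mean/std shapes:", mean.shape, std.shape)
```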
### 2. Run the code

The following commands run the YingLong model; the model output will be saved to `output/result.npy`.

``` shell
export PYTHONPATH=$PWD
python examples/yinglong/predict_12layers.py
```
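After the run finishes, the saved forecast can be loaded back with NumPy. The snippet below is a minimal sketch for inspecting the result; the exact array shape depends on the variables and the number of forecast timestamps.

``` python
import numpy as np

# load the forecast written by predict_12layers.py
result = np.load("output/result.npy")
print(result.shape, result.dtype)
```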

examples/yinglong/predict_12layers.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from os import path as osp

import h5py
import numpy as np
import paddle
import paddle.inference as paddle_infer
import pandas as pd
from packaging import version

from examples.yinglong.timefeatures import time_features
from ppsci.utils import logger


class YingLong:
    def __init__(
        self, model_file: str, params_file: str, mean_path: str, std_path: str
    ):
        self.model_file = model_file
        self.params_file = params_file

        config = paddle_infer.Config(model_file, params_file)
        config.switch_ir_optim(False)
        config.enable_use_gpu(100, 0)
        config.enable_memory_optim()

        self.predictor = paddle_infer.create_predictor(config)

        # get input names and data handles
        self.input_names = self.predictor.get_input_names()
        self.input_data_handle = self.predictor.get_input_handle(self.input_names[0])
        self.time_stamps_handle = self.predictor.get_input_handle(self.input_names[1])
        self.nwp_data_handle = self.predictor.get_input_handle(self.input_names[2])

        # get output names and data handles
        self.output_names = self.predictor.get_output_names()
        self.output_handle = self.predictor.get_output_handle(self.output_names[0])

        # load mean and std data
        self.mean = np.load(mean_path).reshape(-1, 1, 1).astype(np.float32)
        self.std = np.load(std_path).reshape(-1, 1, 1).astype(np.float32)

    def _preprocess_data(self, input_data, time_stamps, nwp_data):
        # normalize data
        input_data = (input_data - self.mean) / self.std
        nwp_data = (nwp_data - self.mean) / self.std

        # process time stamps
        for i in range(len(time_stamps)):
            time_stamps[i] = pd.DataFrame({"date": time_stamps[i]})
            time_stamps[i] = time_features(time_stamps[i], timeenc=1, freq="h").astype(
                np.float32
            )
        time_stamps = np.asarray(time_stamps)
        return input_data, time_stamps, nwp_data

    def _postprocess_data(self, data):
        # denormalize data
        data = data * self.std + self.mean
        return data

    def __call__(self, input_data, time_stamp, nwp_data):
        # preprocess data
        input_data, time_stamps, nwp_data = self._preprocess_data(
            input_data, time_stamp, nwp_data
        )

        # set input data
        self.input_data_handle.copy_from_cpu(input_data)
        self.time_stamps_handle.copy_from_cpu(time_stamps)
        self.nwp_data_handle.copy_from_cpu(nwp_data)

        # run predictor
        self.predictor.run()

        # get output data
        output_data = self.output_handle.copy_to_cpu()

        # postprocess data
        output_data = self._postprocess_data(output_data)
        return output_data


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_file",
        type=str,
        default="./yinglong_models/yinglong_12.pdmodel",
        help="Model filename",
    )
    parser.add_argument(
        "--params_file",
        type=str,
        default="./yinglong_models/yinglong_12.pdiparams",
        help="Parameter filename",
    )
    parser.add_argument(
        "--mean_path",
        type=str,
        default="./hrrr_example_24vars/stat/mean_crop.npy",
        help="Mean filename",
    )
    parser.add_argument(
        "--std_path",
        type=str,
        default="./hrrr_example_24vars/stat/std_crop.npy",
        help="Standard deviation filename",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default="./hrrr_example_24vars/valid/2022/01/01.h5",
        help="Input filename",
    )
    parser.add_argument(
        "--init_time", type=str, default="2022/01/01/00", help="Init time"
    )
    parser.add_argument(
        "--nwp_file",
        type=str,
        default="./hrrr_example_24vars/nwp_convert/2022/01/01/00.h5",
        help="NWP filename",
    )
    parser.add_argument(
        "--num_timestamps", type=int, default=22, help="Number of timestamps"
    )
    parser.add_argument(
        "--output_path", type=str, default="output", help="Output file path"
    )

    return parser.parse_args()


def main():
    args = parse_args()
    logger.init_logger("ppsci", osp.join(args.output_path, "predict.log"), "info")
    if version.Version(paddle.__version__) != version.Version("2.5.2"):
        logger.error(
            f"Your Paddle version is {paddle.__version__}, but this code currently "
            "only supports PaddlePaddle 2.5.2. The latest version of Paddle will be "
            "supported as soon as possible."
        )
        exit()

    num_timestamps = args.num_timestamps

    # create predictor
    predictor = YingLong(
        args.model_file, args.params_file, args.mean_path, args.std_path
    )

    # load data
    input_file = h5py.File(args.input_file, "r")["fields"]
    nwp_file = h5py.File(args.nwp_file, "r")["fields"]
    input_data = input_file[0:1]
    nwp_data = nwp_file[0:num_timestamps]

    # create time stamps
    cur_time = pd.to_datetime(args.init_time, format="%Y/%m/%d/%H")
    time_stamps = [[cur_time]]
    for _ in range(num_timestamps):
        cur_time += pd.Timedelta(hours=1)
        time_stamps.append([cur_time])

    # run predictor
    output_data = predictor(input_data, time_stamps, nwp_data)

    save_path = osp.join(args.output_path, "result.npy")
    logger.info(f"Save output to {save_path}")
    np.save(save_path, output_data)


if __name__ == "__main__":
    main()
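For reference, the `YingLong` wrapper added in this commit can also be driven directly from Python rather than through the command-line defaults. The following is a minimal sketch that mirrors `main()` and reuses the example paths from `parse_args`; run it from the PaddleScience root with `PYTHONPATH` set, as in the README.

``` python
import os

import h5py
import numpy as np
import pandas as pd

from examples.yinglong.predict_12layers import YingLong

predictor = YingLong(
    "./yinglong_models/yinglong_12.pdmodel",
    "./yinglong_models/yinglong_12.pdiparams",
    "./hrrr_example_24vars/stat/mean_crop.npy",
    "./hrrr_example_24vars/stat/std_crop.npy",
)

num_timestamps = 22
# first analysis frame as the initial condition; NWP frames are the model's third input
input_data = h5py.File("./hrrr_example_24vars/valid/2022/01/01.h5", "r")["fields"][0:1]
nwp_data = h5py.File("./hrrr_example_24vars/nwp_convert/2022/01/01/00.h5", "r")["fields"][
    0:num_timestamps
]

# hourly time stamps: the init time plus one stamp per forecast hour
cur_time = pd.to_datetime("2022/01/01/00", format="%Y/%m/%d/%H")
time_stamps = [[cur_time]]
for _ in range(num_timestamps):
    cur_time += pd.Timedelta(hours=1)
    time_stamps.append([cur_time])

output_data = predictor(input_data, time_stamps, nwp_data)
os.makedirs("output", exist_ok=True)
np.save("output/result.npy", output_data)
```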
