1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import argparse
1615from os import path as osp
1716
1817import h5py
18+ import hydra
1919import numpy as np
2020import paddle
2121import pandas as pd
22+ from omegaconf import DictConfig
2223from packaging import version
2324
2425from examples .yinglong .plot import save_plot_weather_from_dict
25- from examples .yinglong .predictor import YingLong
26+ from examples .yinglong .predictor import YingLongPredictor
2627from ppsci .utils import logger
2728
2829
29- def parse_args ():
30- parser = argparse .ArgumentParser ()
31- parser .add_argument (
32- "--model_file" ,
33- type = str ,
34- default = "./yinglong_models/yinglong_12.pdmodel" ,
35- help = "Model filename" ,
36- )
37- parser .add_argument (
38- "--params_file" ,
39- type = str ,
40- default = "./yinglong_models/yinglong_12.pdiparams" ,
41- help = "Parameter filename" ,
42- )
43- parser .add_argument (
44- "--mean_path" ,
45- type = str ,
46- default = "./hrrr_example_24vars/stat/mean_crop.npy" ,
47- help = "Mean filename" ,
48- )
49- parser .add_argument (
50- "--std_path" ,
51- type = str ,
52- default = "./hrrr_example_24vars/stat/std_crop.npy" ,
53- help = "Standard deviation filename" ,
54- )
55- parser .add_argument (
56- "--input_file" ,
57- type = str ,
58- default = "./hrrr_example_24vars/valid/2022/01/01.h5" ,
59- help = "Input filename" ,
60- )
61- parser .add_argument (
62- "--init_time" , type = str , default = "2022/01/01/00" , help = "Init time"
63- )
64- parser .add_argument (
65- "--nwp_file" ,
66- type = str ,
67- default = "./hrrr_example_24vars/nwp_convert/2022/01/01/00.h5" ,
68- help = "NWP filename" ,
69- )
70- parser .add_argument (
71- "--num_timestamps" , type = int , default = 22 , help = "Number of timestamps"
72- )
73- parser .add_argument (
74- "--output_path" , type = str , default = "output" , help = "Output file path"
75- )
76-
77- return parser .parse_args ()
78-
79-
80- def main ():
81- args = parse_args ()
82- logger .init_logger ("ppsci" , osp .join (args .output_path , "predict.log" ), "info" )
30+ def inference (cfg : DictConfig ):
8331 # log paddlepaddle's version
8432 if version .Version (paddle .__version__ ) != version .Version ("0.0.0" ):
8533 paddle_version = paddle .__version__
@@ -93,19 +41,17 @@ def main():
9341
9442 logger .info (f"Using paddlepaddle { paddle_version } " )
9543
96- num_timestamps = args .num_timestamps
44+ num_timestamps = cfg . INFER .num_timestamps
9745 # create predictor
98- predictor = YingLong (
99- args .model_file , args .params_file , args .mean_path , args .std_path
100- )
46+ predictor = YingLongPredictor (cfg )
10147
10248 # load data
10349 # HRRR Crop use 24 atmospheric variable,their index in the dataset is from 0 to 23.
10450 # The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000',
10551 # 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500',
10652 # 'v850', 'v1000', 'mslp', 'u10', 'v10', 't2m'.
107- input_file = h5py .File (args .input_file , "r" )["fields" ]
108- nwp_file = h5py .File (args .nwp_file , "r" )["fields" ]
53+ input_file = h5py .File (cfg . INFER .input_file , "r" )["fields" ]
54+ nwp_file = h5py .File (cfg . INFER .nwp_file , "r" )["fields" ]
10955
11056 # input_data.shape: (1, 24, 440, 408)
11157 input_data = input_file [0 :1 ]
@@ -115,18 +61,18 @@ def main():
11561 ground_truth = input_file [1 : num_timestamps + 1 ]
11662
11763 # create time stamps
118- cur_time = pd .to_datetime (args .init_time , format = "%Y/%m/%d/%H" )
64+ cur_time = pd .to_datetime (cfg . INFER .init_time , format = "%Y/%m/%d/%H" )
11965 time_stamps = [[cur_time ]]
12066 for _ in range (num_timestamps ):
12167 cur_time += pd .Timedelta (hours = 1 )
12268 time_stamps .append ([cur_time ])
12369
12470 # run predictor
125- pred_data = predictor (input_data , time_stamps , nwp_data )
71+ pred_data = predictor . predict (input_data , time_stamps , nwp_data )
12672 pred_data = pred_data .squeeze (axis = 1 ) # (num_timestamps, 24, 440, 408)
12773
12874 # save predict data
129- save_path = osp .join (args . output_path , "result.npy" )
75+ save_path = osp .join (cfg . output_dir , "result.npy" )
13076 np .save (save_path , pred_data )
13177 logger .info (f"Save output to { save_path } " )
13278
@@ -139,15 +85,15 @@ def main():
13985 data_dict = {}
14086 visu_keys = []
14187 for i in range (num_timestamps ):
142- visu_key = f"Init time: { args .init_time } h\n Ground truth: { i + 1 } h"
88+ visu_key = f"Init time: { cfg . INFER .init_time } h\n Ground truth: { i + 1 } h"
14389 visu_keys .append (visu_key )
14490 data_dict [visu_key ] = ground_truth_wind [i ]
145- visu_key = f"Init time: { args .init_time } h\n YingLong-12 Layers: { i + 1 } h"
91+ visu_key = f"Init time: { cfg . INFER .init_time } h\n YingLong-12 Layers: { i + 1 } h"
14692 visu_keys .append (visu_key )
14793 data_dict [visu_key ] = pred_wind [i ]
14894
14995 save_plot_weather_from_dict (
150- foldername = args . output_path ,
96+ foldername = cfg . output_dir ,
15197 data_dict = data_dict ,
15298 visu_keys = visu_keys ,
15399 xticks = np .linspace (0 , 407 , 7 ),
@@ -159,7 +105,15 @@ def main():
159105 colorbar_label = "m/s" ,
160106 num_timestamps = 12 , # only plot 12 timestamps
161107 )
162- logger .info (f"Save plot to { args .output_path } " )
108+ logger .info (f"Save plot to { cfg .output_dir } " )
109+
110+
111+ @hydra .main (version_base = None , config_path = "./conf" , config_name = "yinglong_12.yaml" )
112+ def main (cfg : DictConfig ):
113+ if cfg .mode == "infer" :
114+ inference (cfg )
115+ else :
116+ raise ValueError (f"cfg.mode should in ['infer'], but got '{ cfg .mode } '" )
163117
164118
165119if __name__ == "__main__" :
0 commit comments