-
Notifications
You must be signed in to change notification settings - Fork 184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refine] Refine yinglong code #786
Changes from 4 commits
89c251c
b4e59a6
e7da995
99a4d67
9ec634b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,20 +46,22 @@ wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/yinglong/hrrr_examp | |
tar -xvf hrrr_example_24vars.tar | ||
wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/yinglong/hrrr_example_69vars.tar | ||
tar -xvf hrrr_example_69vars.tar | ||
wget https://paddle-org.bj.bcebos.com/paddlescience/models/yinglong/yinglong_models.tar | ||
tar -xvf yinglong_models.tar | ||
wget https://paddle-org.bj.bcebos.com/paddlescience/models/yinglong/inference.tar | ||
tar -xvf inference.tar | ||
``` | ||
|
||
### 2. Run the code | ||
|
||
The following code runs the Yinglong model, and the model output will be saved in 'output/result.npy'. | ||
|
||
``` shell | ||
cd PaddleScience | ||
export PYTHONPATH=$PWD | ||
cd ./examples/yinglong | ||
# YingLong-12 Layers | ||
python examples/yinglong/predict_12layers.py | ||
python ./predict_12layers.py mode=infer | ||
# YingLong-24 Layers | ||
# python examples/yinglong/predict_24layers.py | ||
python ./predict_24layers.py mode=infer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除 |
||
``` | ||
|
||
We also visualized the predicted wind speed at 10 meters above ground level, with an initial field of 0:00 on January 1, 2022. Click [here](https://paddle-org.bj.bcebos.com/paddlescience/docs/Yinglong/result.gif) to view the prediction results. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
# dir: outputs_yinglong/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||
dir: ./outputs_yinglong_12 | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working direcotry unchaned | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- INFER.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq | ||
callbacks: | ||
init_callback: | ||
_target_: ppsci.utils.callbacks.InitCallback | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
||
# general settings | ||
mode: train # running mode: train/eval | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mode默认值可以改成infer,默认值为目前还不支持的train有点奇怪 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已改为infer |
||
seed: 2023 | ||
output_dir: ${hydra:run.dir} | ||
log_freq: 20 | ||
|
||
# inference settings | ||
INFER: | ||
pretrained_model_path: null | ||
export_path: ./inference/yinglong_12 | ||
pdmodel_path: ${INFER.export_path}.pdmodel | ||
pdpiparams_path: ${INFER.export_path}.pdiparams | ||
onnx_path: ${INFER.export_path}.onnx | ||
device: gpu | ||
engine: native | ||
precision: fp32 | ||
ir_optim: false | ||
min_subgraph_size: 30 | ||
gpu_mem: 100 | ||
gpu_id: 0 | ||
max_batch_size: 1 | ||
num_cpu_threads: 10 | ||
batch_size: 1 | ||
mean_path: ./hrrr_example_24vars/stat/mean_crop.npy | ||
std_path: ./hrrr_example_24vars/stat/std_crop.npy | ||
input_file: ./hrrr_example_24vars/valid/2022/01/01.h5 | ||
init_time: 2022/01/01/00 | ||
nwp_file: ./hrrr_example_24vars/nwp_convert/2022/01/01/00.h5 | ||
num_timestamps: 22 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
# dir: outputs_yinglong/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||
dir: ./outputs_yinglong_24 | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working direcotry unchaned | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- INFER.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq | ||
callbacks: | ||
init_callback: | ||
_target_: ppsci.utils.callbacks.InitCallback | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
||
# general settings | ||
mode: train # running mode: train/eval | ||
seed: 2023 | ||
output_dir: ${hydra:run.dir} | ||
log_freq: 20 | ||
|
||
# inference settings | ||
INFER: | ||
pretrained_model_path: null | ||
export_path: ./inference/yinglong_24 | ||
pdmodel_path: ${INFER.export_path}.pdmodel | ||
pdpiparams_path: ${INFER.export_path}.pdiparams | ||
onnx_path: ${INFER.export_path}.onnx | ||
device: gpu | ||
engine: native | ||
precision: fp32 | ||
ir_optim: false | ||
min_subgraph_size: 30 | ||
gpu_mem: 100 | ||
gpu_id: 0 | ||
max_batch_size: 1 | ||
num_cpu_threads: 10 | ||
batch_size: 1 | ||
mean_path: ./hrrr_example_69vars/stat/mean_crop.npy | ||
std_path: ./hrrr_example_69vars/stat/std_crop.npy | ||
input_file: ./hrrr_example_69vars/valid/2022/01/01.h5 | ||
init_time: 2022/01/01/00 | ||
nwp_file: ./hrrr_example_69vars/nwp_convert/2022/01/01/00.h5 | ||
num_timestamps: 22 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,74 +12,22 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
from os import path as osp | ||
|
||
import h5py | ||
import hydra | ||
import numpy as np | ||
import paddle | ||
import pandas as pd | ||
from omegaconf import DictConfig | ||
from packaging import version | ||
|
||
from examples.yinglong.plot import save_plot_weather_from_dict | ||
from examples.yinglong.predictor import YingLong | ||
from examples.yinglong.predictor import YingLongPredictor | ||
from ppsci.utils import logger | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--model_file", | ||
type=str, | ||
default="./yinglong_models/yinglong_12.pdmodel", | ||
help="Model filename", | ||
) | ||
parser.add_argument( | ||
"--params_file", | ||
type=str, | ||
default="./yinglong_models/yinglong_12.pdiparams", | ||
help="Parameter filename", | ||
) | ||
parser.add_argument( | ||
"--mean_path", | ||
type=str, | ||
default="./hrrr_example_24vars/stat/mean_crop.npy", | ||
help="Mean filename", | ||
) | ||
parser.add_argument( | ||
"--std_path", | ||
type=str, | ||
default="./hrrr_example_24vars/stat/std_crop.npy", | ||
help="Standard deviation filename", | ||
) | ||
parser.add_argument( | ||
"--input_file", | ||
type=str, | ||
default="./hrrr_example_24vars/valid/2022/01/01.h5", | ||
help="Input filename", | ||
) | ||
parser.add_argument( | ||
"--init_time", type=str, default="2022/01/01/00", help="Init time" | ||
) | ||
parser.add_argument( | ||
"--nwp_file", | ||
type=str, | ||
default="./hrrr_example_24vars/nwp_convert/2022/01/01/00.h5", | ||
help="NWP filename", | ||
) | ||
parser.add_argument( | ||
"--num_timestamps", type=int, default=22, help="Number of timestamps" | ||
) | ||
parser.add_argument( | ||
"--output_path", type=str, default="output", help="Output file path" | ||
) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
logger.init_logger("ppsci", osp.join(args.output_path, "predict.log"), "info") | ||
def inference(cfg: DictConfig): | ||
# log paddlepaddle's version | ||
if version.Version(paddle.__version__) != version.Version("0.0.0"): | ||
paddle_version = paddle.__version__ | ||
|
@@ -93,19 +41,20 @@ def main(): | |
|
||
logger.info(f"Using paddlepaddle {paddle_version}") | ||
|
||
num_timestamps = args.num_timestamps | ||
num_timestamps = cfg.INFER.num_timestamps | ||
# create predictor | ||
predictor = YingLong( | ||
args.model_file, args.params_file, args.mean_path, args.std_path | ||
) | ||
# predictor = YingLong( | ||
# args.model_file, args.params_file, args.mean_path, args.std_path | ||
# ) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 之前的代码可以删掉啦 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除 |
||
predictor = YingLongPredictor(cfg) | ||
|
||
# load data | ||
# HRRR Crop use 24 atmospheric variable,their index in the dataset is from 0 to 23. | ||
# The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000', | ||
# 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500', | ||
# 'v850', 'v1000', 'mslp', 'u10', 'v10', 't2m'. | ||
input_file = h5py.File(args.input_file, "r")["fields"] | ||
nwp_file = h5py.File(args.nwp_file, "r")["fields"] | ||
input_file = h5py.File(cfg.INFER.input_file, "r")["fields"] | ||
nwp_file = h5py.File(cfg.INFER.nwp_file, "r")["fields"] | ||
|
||
# input_data.shape: (1, 24, 440, 408) | ||
input_data = input_file[0:1] | ||
|
@@ -115,18 +64,18 @@ def main(): | |
ground_truth = input_file[1 : num_timestamps + 1] | ||
|
||
# create time stamps | ||
cur_time = pd.to_datetime(args.init_time, format="%Y/%m/%d/%H") | ||
cur_time = pd.to_datetime(cfg.INFER.init_time, format="%Y/%m/%d/%H") | ||
time_stamps = [[cur_time]] | ||
for _ in range(num_timestamps): | ||
cur_time += pd.Timedelta(hours=1) | ||
time_stamps.append([cur_time]) | ||
|
||
# run predictor | ||
pred_data = predictor(input_data, time_stamps, nwp_data) | ||
pred_data = predictor.predict(input_data, time_stamps, nwp_data) | ||
pred_data = pred_data.squeeze(axis=1) # (num_timestamps, 24, 440, 408) | ||
|
||
# save predict data | ||
save_path = osp.join(args.output_path, "result.npy") | ||
save_path = osp.join(cfg.output_dir, "result.npy") | ||
np.save(save_path, pred_data) | ||
logger.info(f"Save output to {save_path}") | ||
|
||
|
@@ -139,15 +88,15 @@ def main(): | |
data_dict = {} | ||
visu_keys = [] | ||
for i in range(num_timestamps): | ||
visu_key = f"Init time: {args.init_time}h\n Ground truth: {i+1}h" | ||
visu_key = f"Init time: {cfg.INFER.init_time}h\n Ground truth: {i+1}h" | ||
visu_keys.append(visu_key) | ||
data_dict[visu_key] = ground_truth_wind[i] | ||
visu_key = f"Init time: {args.init_time}h\n YingLong-12 Layers: {i+1}h" | ||
visu_key = f"Init time: {cfg.INFER.init_time}h\n YingLong-12 Layers: {i+1}h" | ||
visu_keys.append(visu_key) | ||
data_dict[visu_key] = pred_wind[i] | ||
|
||
save_plot_weather_from_dict( | ||
foldername=args.output_path, | ||
foldername=cfg.output_dir, | ||
data_dict=data_dict, | ||
visu_keys=visu_keys, | ||
xticks=np.linspace(0, 407, 7), | ||
|
@@ -159,7 +108,15 @@ def main(): | |
colorbar_label="m/s", | ||
num_timestamps=12, # only plot 12 timestamps | ||
) | ||
logger.info(f"Save plot to {args.output_path}") | ||
logger.info(f"Save plot to {cfg.output_dir}") | ||
|
||
|
||
@hydra.main(version_base=None, config_path="./conf", config_name="yinglong_12.yaml") | ||
def main(cfg: DictConfig): | ||
if cfg.mode == "infer": | ||
inference(cfg) | ||
else: | ||
raise ValueError(f"cfg.mode should in ['infer'], but got '{cfg.mode}'") | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不需要这个“./”吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除