@@ -65,16 +65,16 @@ def inference(cfg: DictConfig):
6565 # The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000',
6666 # 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500',
6767 # 'v850', 'v1000', 'mslp', 'u10', 'v10', 't2m'.
68- input_file = read_h5py (cfg .INFER .input_file )
69- nwp_file = read_h5py (cfg .INFER .nwp_file )
70- geo_file = read_h5py (cfg .INFER .geo_file )
68+ input_data = read_h5py (cfg .INFER .input_file )
69+ nwp_data = read_h5py (cfg .INFER .nwp_file )
70+ geo_data = read_h5py (cfg .INFER .geo_file )
7171
7272 # input_data.shape: (1, 24, 440, 408)
73- input_data = input_file [0 :1 ]
73+ input_data_0 = input_data [0 :1 ]
7474 # nwp_data.shape: # (num_timestamps, 24, 440, 408)
75- nwp_data = nwp_file [0 :num_timestamps ]
75+ nwp_data = nwp_data [0 :num_timestamps ]
7676 # ground_truth.shape: (num_timestamps, 24, 440, 408)
77- ground_truth = input_file [1 : num_timestamps + 1 ]
77+ ground_truth = input_data [1 : num_timestamps + 1 ]
7878
7979 # create time stamps
8080 cur_time = pd .to_datetime (cfg .INFER .init_time , format = "%Y/%m/%d/%H" )
@@ -84,7 +84,7 @@ def inference(cfg: DictConfig):
8484 time_stamps .append ([cur_time ])
8585
8686 # run predictor
87- pred_data = predictor .predict (input_data , time_stamps , nwp_data , geo_file )
87+ pred_data = predictor .predict (input_data_0 , time_stamps , nwp_data , geo_data )
8888 pred_data = pred_data .squeeze (axis = 1 ) # (num_timestamps, 24, 440, 408)
8989
9090 # save predict data
0 commit comments