diff --git a/docs/zh/examples/phylstm.md b/docs/zh/examples/phylstm.md index 3d4e0922d..da465a6c2 100644 --- a/docs/zh/examples/phylstm.md +++ b/docs/zh/examples/phylstm.md @@ -44,6 +44,44 @@ python phylstm3.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/phylstm/phylstm3_pretrained.pdparams ``` +=== "模型导出命令" + + === "phylstm2" + + ``` sh + python phylstm2.py mode=export + ``` + + === "phylstm3" + + ``` sh + python phylstm3.py mode=export + ``` + +=== "模型推理命令" + + === "phylstm2" + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat -o data_boucwen.mat + python phylstm2.py mode=infer + ``` + + === "phylstm3" + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat -o data_boucwen.mat + python phylstm3.py mode=infer + ``` + + + | 预训练模型 | 指标 | |:--| :--| | [phylstm2_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/phylstm/phylstm2_pretrained.pdparams) | loss(sup_valid): 0.00799 | diff --git a/examples/phylstm/conf/phylstm2.yaml b/examples/phylstm/conf/phylstm2.yaml index d896194f1..40efa8e26 100644 --- a/examples/phylstm/conf/phylstm2.yaml +++ b/examples/phylstm/conf/phylstm2.yaml @@ -48,3 +48,21 @@ TRAIN: EVAL: pretrained_model_path: null eval_with_no_grad: true + +# inference settings +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/phylstm/phylstm2_pretrained.pdparams + export_path: ./inference/phylstm2 + pdmodel_path: ${INFER.export_path}.pdmodel + pdiparams_path: ${INFER.export_path}.pdiparams + onnx_path: ${INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 10240 + num_cpu_threads: 10 + batch_size: 10240 diff --git a/examples/phylstm/conf/phylstm3.yaml b/examples/phylstm/conf/phylstm3.yaml index 222917d35..c27c21a93 100644 --- a/examples/phylstm/conf/phylstm3.yaml +++ b/examples/phylstm/conf/phylstm3.yaml @@ -48,3 +48,21 @@ TRAIN: EVAL: pretrained_model_path: null eval_with_no_grad: true + +# inference settings +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/phylstm/phylstm3_pretrained.pdparams + export_path: ./inference/phylstm3 + pdmodel_path: ${INFER.export_path}.pdmodel + pdiparams_path: ${INFER.export_path}.pdiparams + onnx_path: ${INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 10240 + num_cpu_threads: 10 + batch_size: 10240 diff --git a/examples/phylstm/phylstm2.py b/examples/phylstm/phylstm2.py index 0d320e77f..dd7e5ff3a 100755 --- a/examples/phylstm/phylstm2.py +++ b/examples/phylstm/phylstm2.py @@ -308,14 +308,144 @@ def evaluate(cfg: DictConfig): solver.eval() +def export(cfg: DictConfig): + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) + u_data = mat["target_X_tf"] + u_data = u_data.reshape([u_data.shape[0], u_data.shape[1], 1]) + u_all = u_data + eta_star = u_all[0:10] + eta = eta_star + # set model + model = ppsci.arch.DeepPhyLSTM( + cfg.MODEL.input_size, + eta.shape[2], + cfg.MODEL.hidden_size, + cfg.MODEL.model_type, + ) + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + predictor = pinn_predictor.PINNPredictor(cfg) + + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) + ag_data = mat["input_tf"] # ag, ad, av + u_data = mat["target_X_tf"] + ut_data = mat["target_Xd_tf"] + utt_data = mat["target_Xdd_tf"] + ag_data = ag_data.reshape([ag_data.shape[0], ag_data.shape[1], 1]) + u_data = u_data.reshape([u_data.shape[0], u_data.shape[1], 1]) + ut_data = ut_data.reshape([ut_data.shape[0], ut_data.shape[1], 1]) + utt_data = utt_data.reshape([utt_data.shape[0], utt_data.shape[1], 1]) + + t = mat["time"] + dt = t[0, 1] - t[0, 0] + + ag_all = ag_data + u_all = u_data + u_t_all = ut_data + u_tt_all = utt_data + + # finite difference + N = u_data.shape[1] + phi1 = np.concatenate( + [ + np.array([-3 / 2, 2, -1 / 2]), + np.zeros([N - 3]), + ] + ) + temp1 = np.concatenate([-1 / 2 * np.identity(N - 2), np.zeros([N - 2, 2])], axis=1) + temp2 = np.concatenate([np.zeros([N - 2, 2]), 1 / 2 * np.identity(N - 2)], axis=1) + phi2 = temp1 + temp2 + phi3 = np.concatenate( + [ + np.zeros([N - 3]), + np.array([1 / 2, -2, 3 / 2]), + ] + ) + phi_t0 = ( + 1 + / dt + * np.concatenate( + [ + np.reshape(phi1, [1, phi1.shape[0]]), + phi2, + np.reshape(phi3, [1, phi3.shape[0]]), + ], + axis=0, + ) + ) + phi_t0 = np.reshape(phi_t0, [1, N, N]) + + ag_star = ag_all[0:10] + eta_star = u_all[0:10] + eta_t_star = u_t_all[0:10] + eta_tt_star = u_tt_all[0:10] + ag_c_star = ag_all[0:50] + lift_star = -ag_c_star + + eta = eta_star + ag = ag_star + lift = lift_star + eta_t = eta_t_star + eta_tt = eta_tt_star + ag_c = ag_c_star + g = -eta_tt - ag + phi_t = np.repeat(phi_t0, ag_c_star.shape[0], axis=0) + + input_dict = { + "eta": eta, + "eta_t": eta_t, + "g": g, + "ag": ag, + "ag_c": ag_c, + "lift": lift, + "phi_t": phi_t, + } + + output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) + + # mapping data to cfg.INFER.output_keys + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) + } + + ppsci.visualize.save_vtu_from_dict( + "./phylstm2_pred.vtu", + {**input_dict, **output_dict}, + input_dict.keys(), + cfg.MODEL.output_keys, + ) + + @hydra.main(version_base=None, config_path="./conf", config_name="phylstm2.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__": diff --git a/examples/phylstm/phylstm3.py b/examples/phylstm/phylstm3.py index 071ecbeed..fd12721f0 100755 --- a/examples/phylstm/phylstm3.py +++ b/examples/phylstm/phylstm3.py @@ -346,14 +346,151 @@ def evaluate(cfg: DictConfig): solver.eval() +def export(cfg: DictConfig): + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) + u_data = mat["target_X_tf"] + u_data = u_data.reshape([u_data.shape[0], u_data.shape[1], 1]) + u_all = u_data + eta_star = u_all[0:10] + eta = eta_star + # set model + model = ppsci.arch.DeepPhyLSTM( + cfg.MODEL.input_size, + eta.shape[2], + cfg.MODEL.hidden_size, + cfg.MODEL.model_type, + ) + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + predictor = pinn_predictor.PINNPredictor(cfg) + + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) + t = mat["time"] + dt = 0.02 + n1 = int(dt / 0.005) + t = t[::n1] + + ag_data = mat["input_tf"][:, ::n1] # ag, ad, av + u_data = mat["target_X_tf"][:, ::n1] + ut_data = mat["target_Xd_tf"][:, ::n1] + utt_data = mat["target_Xdd_tf"][:, ::n1] + ag_data = ag_data.reshape([ag_data.shape[0], ag_data.shape[1], 1]) + u_data = u_data.reshape([u_data.shape[0], u_data.shape[1], 1]) + ut_data = ut_data.reshape([ut_data.shape[0], ut_data.shape[1], 1]) + utt_data = utt_data.reshape([utt_data.shape[0], utt_data.shape[1], 1]) + + ag_pred = mat["input_pred_tf"][:, ::n1] # ag, ad, av + u_pred = mat["target_pred_X_tf"][:, ::n1] + ut_pred = mat["target_pred_Xd_tf"][:, ::n1] + utt_pred = mat["target_pred_Xdd_tf"][:, ::n1] + ag_pred = ag_pred.reshape([ag_pred.shape[0], ag_pred.shape[1], 1]) + u_pred = u_pred.reshape([u_pred.shape[0], u_pred.shape[1], 1]) + ut_pred = ut_pred.reshape([ut_pred.shape[0], ut_pred.shape[1], 1]) + utt_pred = utt_pred.reshape([utt_pred.shape[0], utt_pred.shape[1], 1]) + + N = u_data.shape[1] + phi1 = np.concatenate( + [ + np.array([-3 / 2, 2, -1 / 2]), + np.zeros([N - 3]), + ] + ) + temp1 = np.concatenate([-1 / 2 * np.identity(N - 2), np.zeros([N - 2, 2])], axis=1) + temp2 = np.concatenate([np.zeros([N - 2, 2]), 1 / 2 * np.identity(N - 2)], axis=1) + phi2 = temp1 + temp2 + phi3 = np.concatenate( + [ + np.zeros([N - 3]), + np.array([1 / 2, -2, 3 / 2]), + ] + ) + phi_t0 = ( + 1 + / dt + * np.concatenate( + [ + np.reshape(phi1, [1, phi1.shape[0]]), + phi2, + np.reshape(phi3, [1, phi3.shape[0]]), + ], + axis=0, + ) + ) + phi_t0 = np.reshape(phi_t0, [1, N, N]) + + ag_star = ag_data + eta_star = u_data + eta_t_star = ut_data + eta_tt_star = utt_data + ag_c_star = np.concatenate([ag_data, ag_pred[0:53]]) + lift_star = -ag_c_star + + eta = eta_star + ag = ag_star + lift = lift_star + eta_t = eta_t_star + eta_tt = eta_tt_star + g = -eta_tt - ag + ag_c = ag_c_star + + phi_t = np.repeat(phi_t0, ag_c_star.shape[0], axis=0) + + input_dict = { + "eta": eta, + "eta_t": eta_t, + "eta_tt": eta_tt, + "g": g, + "ag": ag, + "ag_c": ag_c, + "lift": lift, + "phi_t": phi_t, + } + + output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) + + # mapping data to cfg.INFER.output_keys + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) + } + + ppsci.visualize.save_vtu_from_dict( + "./phylstm3_pred.vtu", + {**input_dict, **output_dict}, + input_dict.keys(), + cfg.MODEL.output_keys, + ) + + @hydra.main(version_base=None, config_path="./conf", config_name="phylstm3.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__":