diff --git a/docs/zh/examples/phygeonet.md b/docs/zh/examples/phygeonet.md index 5daf182cfb..9d51a12604 100644 --- a/docs/zh/examples/phygeonet.md +++ b/docs/zh/examples/phygeonet.md @@ -53,6 +53,40 @@ ``` +=== "模型导出命令" + + ``` sh + # heat_equation + python heat_equation.py mode=export + + # heat_equation_bc + python heat_equation_with_bc.py mode=export + ``` + +=== "模型推理命令" + + ``` sh + # heat_equation + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation.npz -P ./data/ + + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation.npz --create-dirs -o ./data/heat_equation.npz + + python heat_equation.py mode=infer + + # heat_equation_bc + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc.npz -P ./data/ + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc_test.npz -P ./data/ + + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc.npz --create-dirs -o ./data/heat_equation.npz + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc_test.npz --create-dirs -o ./data/heat_equation.npz + + python heat_equation_with_bc.py mode=infer + ``` + | 模型 | mRes | ev | | :-- | :-- | :-- | | [heat_equation_pretrain.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_pretrain.pdparams) | 0.815 |0.095| diff --git a/examples/phygeonet/conf/heat_equation.yaml b/examples/phygeonet/conf/heat_equation.yaml index 63fdc4ebe3..d98a08807a 100644 --- a/examples/phygeonet/conf/heat_equation.yaml +++ b/examples/phygeonet/conf/heat_equation.yaml @@ -51,3 +51,21 @@ TRAIN: EVAL: pretrained_model_path: null eval_with_no_grad: true + +# inference settings +INFER: + pretrained_model_path: 'https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_pretrain.pdparams' + export_path: ./inference/heat_equation + pdmodel_path: ${INFER.export_path}.pdmodel + pdiparams_path: ${INFER.export_path}.pdiparams + onnx_path: ${INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 20 + gpu_id: 0 + max_batch_size: 256 + num_cpu_threads: 10 + batch_size: 256 diff --git a/examples/phygeonet/conf/heat_equation_with_bc.yaml b/examples/phygeonet/conf/heat_equation_with_bc.yaml index af92466c7b..6bf392d6cb 100644 --- a/examples/phygeonet/conf/heat_equation_with_bc.yaml +++ b/examples/phygeonet/conf/heat_equation_with_bc.yaml @@ -24,6 +24,7 @@ hydra: # general settings mode: train # running mode: train/eval +log_freq: 50 seed: 66 data_dir: ./data/heat_equation_bc.npz test_data_dir: ./data/heat_equation_bc_test.npz @@ -53,3 +54,21 @@ TRAIN: EVAL: pretrained_model_path: null eval_with_no_grad: true + +# inference settings +INFER: + pretrained_model_path: 'https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_bc_pretrain.pdparams' + export_path: ./inference/heat_equation_bc + pdmodel_path: ${INFER.export_path}.pdmodel + pdiparams_path: ${INFER.export_path}.pdiparams + onnx_path: ${INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 20 + gpu_id: 0 + max_batch_size: 256 + num_cpu_threads: 10 + batch_size: 256 diff --git a/examples/phygeonet/heat_equation.py b/examples/phygeonet/heat_equation.py index ba7869b6e2..19cb981640 100644 --- a/examples/phygeonet/heat_equation.py +++ b/examples/phygeonet/heat_equation.py @@ -1,3 +1,4 @@ +import os.path as osp from typing import Dict import hydra @@ -153,14 +154,102 @@ def evaluate(cfg: DictConfig): plt.close(fig) +def export(cfg: DictConfig): + model = ppsci.arch.USCNN(**cfg.MODEL) + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + { + key: InputSpec([None, 2, 19, 84], "float32", name=key) + for key in model.input_keys + }, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + predictor = pinn_predictor.PINNPredictor(cfg) + data = np.load(cfg.data_dir) + coords = data["coords"] + ofv_sb = data["OFV_sb"] + + ## create model + pad_singleside = cfg.MODEL.pad_singleside + input_spec = {"coords": coords} + + output_v = predictor.predict(input_spec, cfg.INFER.batch_size) + # mapping data to cfg.INFER.output_keys + output_v = { + store_key: output_v[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_v.keys()) + } + + output_v = output_v["output_v"] + + output_v[0, 0, -pad_singleside:, pad_singleside:-pad_singleside] = 0 + output_v[0, 0, :pad_singleside, pad_singleside:-pad_singleside] = 1 + output_v[0, 0, pad_singleside:-pad_singleside, -pad_singleside:] = 1 + output_v[0, 0, pad_singleside:-pad_singleside, 0:pad_singleside] = 1 + output_v[0, 0, 0, 0] = 0.5 * (output_v[0, 0, 0, 1] + output_v[0, 0, 1, 0]) + output_v[0, 0, 0, -1] = 0.5 * (output_v[0, 0, 0, -2] + output_v[0, 0, 1, -1]) + + ev = paddle.sqrt( + paddle.mean((ofv_sb - output_v[0, 0]) ** 2) / paddle.mean(ofv_sb**2) + ).item() + logger.info(f"ev: {ev}") + + fig = plt.figure() + ax = plt.subplot(1, 2, 1) + utils.visualize( + ax, + coords[0, 0, 1:-1, 1:-1], + coords[0, 1, 1:-1, 1:-1], + output_v[0, 0, 1:-1, 1:-1], + "horizontal", + [0, 1], + ) + utils.set_axis_label(ax, "p") + ax.set_title("CNN " + r"$T$") + ax.set_aspect("equal") + ax = plt.subplot(1, 2, 2) + utils.visualize( + ax, + coords[0, 0, 1:-1, 1:-1], + coords[0, 1, 1:-1, 1:-1], + ofv_sb[1:-1, 1:-1], + "horizontal", + [0, 1], + ) + utils.set_axis_label(ax, "p") + ax.set_aspect("equal") + ax.set_title("FV " + r"$T$") + fig.tight_layout(pad=1) + fig.savefig(osp.join(cfg.output_dir, "result.png"), bbox_inches="tight") + plt.close(fig) + + @hydra.main(version_base=None, config_path="./conf", config_name="heat_equation.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__": diff --git a/examples/phygeonet/heat_equation_with_bc.py b/examples/phygeonet/heat_equation_with_bc.py index 0af45be706..a03b3bee02 100644 --- a/examples/phygeonet/heat_equation_with_bc.py +++ b/examples/phygeonet/heat_equation_with_bc.py @@ -188,6 +188,119 @@ def evaluate(cfg: DictConfig): plt.close(fig1) +def export(cfg: DictConfig): + model = ppsci.arch.USCNN(**cfg.MODEL) + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + { + key: InputSpec([None, 1, 19, 84], "float32", name=key) + for key in model.input_keys + }, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + predictor = pinn_predictor.PINNPredictor(cfg) + pad_singleside = cfg.MODEL.pad_singleside + + data = np.load(cfg.test_data_dir) + paras = data["paras"] + truths = data["truths"] + coords = data["coords"] + + paras = paras.reshape([paras.shape[0], 1, paras.shape[1], paras.shape[2]]) + input_spec = {"coords": paras} + output_v = predictor.predict(input_spec, cfg.INFER.batch_size) + # mapping data to cfg.INFER.output_keys + output_v = { + store_key: output_v[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_v.keys()) + } + output_v = output_v["output_v"] + num_sample = output_v.shape[0] + for j in range(num_sample): + # Impose BC + output_v[j, 0, -pad_singleside:, pad_singleside:-pad_singleside] = output_v[ + j, 0, 1:2, pad_singleside:-pad_singleside + ] + output_v[j, 0, :pad_singleside, pad_singleside:-pad_singleside] = output_v[ + j, 0, -2:-1, pad_singleside:-pad_singleside + ] + output_v[j, 0, :, -pad_singleside:] = 0 + output_v[j, 0, :, 0:pad_singleside] = paras[j, 0, 0, 0] + + error = paddle.sqrt( + paddle.mean((truths - output_v) ** 2) / paddle.mean(truths**2) + ).item() + logger.info(f"The average error: {error / num_sample}") + + output_vs = output_v + PARALIST = [1, 2, 3, 4, 5, 6, 7] + for i in range(len(PARALIST)): + truth = truths[i] + coord = coords[i] + output_v = output_vs[i] + truth = truth.reshape(1, 1, truth.shape[0], truth.shape[1]) + coord = coord.reshape(1, 2, coord.shape[2], coord.shape[3]) + fig1 = plt.figure() + xylabelsize = 20 + xytickssize = 20 + titlesize = 20 + ax = plt.subplot(1, 2, 1) + _, cbar = utils.visualize( + ax, + coord[0, 0, :, :], + coord[0, 1, :, :], + output_v[0, :, :], + "horizontal", + [0, max(PARALIST)], + ) + ax.set_aspect("equal") + utils.set_axis_label(ax, "p") + ax.set_title("PhyGeoNet " + r"$T$", fontsize=titlesize) + ax.set_xlabel(xlabel=r"$x$", fontsize=xylabelsize) + ax.set_ylabel(ylabel=r"$y$", fontsize=xylabelsize) + ax.set_xticks([-1, 0, 1]) + ax.set_yticks([-1, 0, 1]) + ax.tick_params(axis="x", labelsize=xytickssize) + ax.tick_params(axis="y", labelsize=xytickssize) + cbar.set_ticks([0, 1, 2, 3, 4, 5, 6, 7]) + cbar.ax.tick_params(labelsize=xytickssize) + ax = plt.subplot(1, 2, 2) + _, cbar = utils.visualize( + ax, + coord[0, 0, :, :], + coord[0, 1, :, :], + truth[0, 0, :, :], + "horizontal", + [0, max(PARALIST)], + ) + ax.set_aspect("equal") + utils.set_axis_label(ax, "p") + ax.set_title("FV " + r"$T$", fontsize=titlesize) + ax.set_xlabel(xlabel=r"$x$", fontsize=xylabelsize) + ax.set_ylabel(ylabel=r"$y$", fontsize=xylabelsize) + ax.set_xticks([-1, 0, 1]) + ax.set_yticks([-1, 0, 1]) + ax.tick_params(axis="x", labelsize=xytickssize) + ax.tick_params(axis="y", labelsize=xytickssize) + cbar.set_ticks([0, 1, 2, 3, 4, 5, 6, 7]) + cbar.ax.tick_params(labelsize=xytickssize) + fig1.tight_layout(pad=1) + fig1.savefig(osp.join(cfg.output_dir, f"Para{i}T.png"), bbox_inches="tight") + plt.close(fig1) + + @hydra.main( version_base=None, config_path="./conf", config_name="heat_equation_with_bc.yaml" ) @@ -196,8 +309,14 @@ def main(cfg: DictConfig): train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__":