From b1aeec977ac7fbb68c76b15448015060c11fd837 Mon Sep 17 00:00:00 2001 From: Sylence8 <1656331403@qq.com> Date: Mon, 25 Nov 2024 18:01:40 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90SCU=E3=80=91=E3=80=90PPSCI=20Export&In?= =?UTF-8?q?fer=20No.19=E3=80=91phycrnet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/zh/examples/phycrnet.md | 21 +++++ examples/phycrnet/conf/burgers_equations.yaml | 18 ++++ .../conf/fitzhugh_nagumo_RD_equation.yaml | 18 ++++ .../conf/lambda_omega_RD_equation.yaml | 18 ++++ examples/phycrnet/main.py | 87 ++++++++++++++++++- 5 files changed, 161 insertions(+), 1 deletion(-) diff --git a/docs/zh/examples/phycrnet.md b/docs/zh/examples/phycrnet.md index 681e8035b..4dac3251a 100644 --- a/docs/zh/examples/phycrnet.md +++ b/docs/zh/examples/phycrnet.md @@ -24,6 +24,27 @@ python main.py mode=eval DATA_PATH=./data/burgers_1501x2x128x128.mat EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams ``` + +=== "模型导出命令" + + ``` sh + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat -P ./data/ + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat --create-dirs -o ./data/burgers_1501x2x128x128.mat + + python main.py mode=export DATA_PATH=./data/burgers_1501x2x128x128.mat EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams + ``` + +=== "模型推理命令" + + ``` sh + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat -P ./data/ + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat --create-dirs -o ./data/burgers_1501x2x128x128.mat + + python main.py mode=infer DATA_PATH=./data/burgers_1501x2x128x128.mat EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams ``` + + | 预训练模型 | 指标 | |:--| :--| | [phycrnet_burgers_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams) | a-RMSE: 3.20e-3 | diff --git a/examples/phycrnet/conf/burgers_equations.yaml b/examples/phycrnet/conf/burgers_equations.yaml index e3d95cf23..b20970345 100644 --- a/examples/phycrnet/conf/burgers_equations.yaml +++ b/examples/phycrnet/conf/burgers_equations.yaml @@ -63,3 +63,21 @@ EVAL: eval_with_no_grad: true TIME_BATCH_SIZE: 2000 TIME_STEPS: 2001 + +# inference settings +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams + export_path: ./inference/phycrnet_ide + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 64 + num_cpu_threads: 4 + batch_size: 16 diff --git a/examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml b/examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml index 281e8488b..29eb06724 100644 --- a/examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml +++ b/examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml @@ -61,3 +61,21 @@ TRAIN: EVAL: pretrained_model_path: null eval_with_no_grad: true + +# inference settings +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams + export_path: ./inference/phycrnet_ide + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 64 + num_cpu_threads: 4 + batch_size: 16 diff --git a/examples/phycrnet/conf/lambda_omega_RD_equation.yaml b/examples/phycrnet/conf/lambda_omega_RD_equation.yaml index b5d7aaa67..d40a589d9 100644 --- a/examples/phycrnet/conf/lambda_omega_RD_equation.yaml +++ b/examples/phycrnet/conf/lambda_omega_RD_equation.yaml @@ -61,3 +61,21 @@ TRAIN: EVAL: pretrained_model_path: null eval_with_no_grad: true + +# inference settings +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams + export_path: ./inference/phycrnet_ide + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 64 + num_cpu_threads: 4 + batch_size: 16 diff --git a/examples/phycrnet/main.py b/examples/phycrnet/main.py index ef0f8fab8..6176a8103 100644 --- a/examples/phycrnet/main.py +++ b/examples/phycrnet/main.py @@ -178,6 +178,85 @@ def _transform_out(_in, _out): functions.output_graph(model, input_dict_val, cfg.output_dir, cfg.case_name) +def export(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + # initialize logger + logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") + + # set initial states for convlstm + NUM_CONVLSTM = cfg.num_convlstm + (h0, c0) = (paddle.randn((1, 128, 16, 16)), paddle.randn((1, 128, 16, 16))) + initial_state = [] + for _ in range(NUM_CONVLSTM): + initial_state.append((h0, c0)) + + # grid parameters + time_steps = cfg.TIME_STEPS + dx = cfg.DX[0] / cfg.DX[1] + + steps = cfg.EVAL.TIME_BATCH_SIZE + 1 + effective_step = list(range(0, steps)) + num_time_batch = int(time_steps / cfg.EVAL.TIME_BATCH_SIZE) + + functions.dt = cfg.DT + functions.dx = dx + functions.num_time_batch = num_time_batch + model = ppsci.arch.PhyCRNet( + dt=cfg.DT, step=steps, effective_step=effective_step, **cfg.MODEL + ) + + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + { + key: InputSpec([None, 1], "float32", name=key) + for key in cfg.MODEL.input_keys + }, + ] + solver.export(input_spec, cfg.INFER.export_path, with_onnx=False) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + predictor = pinn_predictor.PINNPredictor(cfg) + + # use Burgers_2d_solver_HighOrder.py to generate data + data = scio.loadmat(cfg.DATA_PATH) + uv = data["uv"] # [t,c,h,w] + functions.uv = uv + + initial_state = [] + # generate input data + (_, _, input_dict, _,) = functions.Dataset( + paddle.to_tensor(initial_state), + paddle.to_tensor( + uv[ + 0:1, + ], + dtype=paddle.get_default_dtype(), + ), + ).get() + + output_dict = predictor.predict( + {key: input_dict[key] for key in cfg.INFER.input_keys}, cfg.INFER.batch_size + ) + # mapping data to cfg.INFER.output_keys + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) + } + + functions.output_graph(predictor, input_dict, cfg.output_dir, cfg.case_name) + + @hydra.main( version_base=None, config_path="./conf", config_name="burgers_equations.yaml" ) @@ -186,8 +265,14 @@ def main(cfg: DictConfig): train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__":