From 08685309a2fc46ce67fd65e36d50aa218ac0c44d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Sat, 2 Mar 2024 21:43:08 +0800 Subject: [PATCH 1/2] add export and inference --- docs/zh/examples/topopt.md | 16 +++++ examples/topopt/conf/topopt.yaml | 31 +++++++++ examples/topopt/topopt.py | 113 ++++++++++++++++++++++++++++++- examples/topopt/topoptmodel.py | 27 +++++--- 4 files changed, 174 insertions(+), 13 deletions(-) diff --git a/docs/zh/examples/topopt.md b/docs/zh/examples/topopt.md index 1cdd1c0a1..2722b95ce 100644 --- a/docs/zh/examples/topopt.md +++ b/docs/zh/examples/topopt.md @@ -22,6 +22,22 @@ python topopt.py mode=eval 'EVAL.pretrained_model_path_dict={'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'}' ``` +=== "模型导出命令" + + ``` sh + python topopt.py mode=export INFER.pretrained_model_name=Uniform + ``` + +=== "模型推理命令" + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/ + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5 + python topopt.py mode=infer INFER.img_num=3 + ``` + | 预训练模型 | 指标 | |:--| :--| | [topopt_uniform_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams) | loss(sup_validator): [0.14336, 0.10211, 0.07927, 0.06433, 0.04970, 0.04612, 0.04201, 0.03566, 0.03623, 0.03314, 0.02929, 0.02857, 0.02498, 0.02517, 0.02523, 0.02618]
metric.Binary_Acc(sup_validator): [0.9410, 0.9673, 0.9718, 0.9727, 0.9818, 0.9824, 0.9826, 0.9845, 0.9856, 0.9892, 0.9892, 0.9907, 0.9890, 0.9916, 0.9914, 0.9922]
metric.IoU(sup_validator): [0.8887, 0.9367, 0.9452, 0.9468, 0.9644, 0.9655, 0.9659, 0.9695, 0.9717, 0.9787, 0.9787, 0.9816, 0.9784, 0.9835, 0.9831, 0.9845] | diff --git a/examples/topopt/conf/topopt.yaml b/examples/topopt/conf/topopt.yaml index 6f6a7fbe7..6ff8b8ad8 100644 --- a/examples/topopt/conf/topopt.yaml +++ b/examples/topopt/conf/topopt.yaml @@ -14,8 +14,13 @@ hydra: - EVAL.pretrained_model_path_dict - EVAL.batch_size - EVAL.num_val_step + - EXPORT.pretrained_model_name + - INFER.pretrained_model_path_dict + - INFER.export_path + - INFER.batch_size - mode - vol_coeff + - log_freq sweep: # output directory for multirun dir: ${hydra.run.dir} @@ -25,6 +30,7 @@ hydra: mode: train # running mode: train/eval seed: 42 output_dir: ${hydra:run.dir} +log_freq: 20 # set default cases parameters CASE_PARAM: [[Poisson, 5], [Poisson, 10], [Poisson, 30], [Uniform, null]] @@ -57,3 +63,28 @@ EVAL: pretrained_model_path_dict: null # a dict: {casename1:path1, casename2:path2, casename3:path3, casename4:path4} num_val_step: 10 # the number of iteration for each evaluation case batch_size: 16 + +# inference settings +INFER: + pretrained_model_name: null # a string, indicating which model you want to export. Support [Uniform, Poisson5, Poisson10, Poisson30]. + pretrained_model_path_dict: {'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'} + export_path: ./inference/topopt + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: null + ir_optim: true + min_subgraph_size: 30 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + batch_size: 4 + sampler_key: Fixed # a string, indicating the sampling method. Support [Fixed, Uniform, Poisson]. + sampler_num: 8 # a integer number, indicating the sampling rate of the sampling method, supported when `sampler_key` is Fixed or Poisson. + img_num: 4 + res_img_figsize: null + save_res_path: ./inference/predicted + save_npy: false diff --git a/examples/topopt/topopt.py b/examples/topopt/topopt.py index 0a9768bfd..a4b1d240b 100644 --- a/examples/topopt/topopt.py +++ b/examples/topopt/topopt.py @@ -17,13 +17,16 @@ import functions as func_module import h5py import hydra +import matplotlib.pyplot as plt import numpy as np import paddle from omegaconf import DictConfig from paddle import nn +from paddle.static import InputSpec from topoptmodel import TopOptNN import ppsci +from deploy.python_infer import pinn_predictor from ppsci.utils import logger @@ -120,7 +123,7 @@ def evaluate(cfg: DictConfig): # fixed iteration stop times for evaluation iterations_stop_times = range(5, 85, 5) - model = TopOptNN() + model = TopOptNN(**cfg.MODEL) # evaluation for 4 cases acc_results_summary = {} @@ -317,14 +320,120 @@ def val_metric(output_dict, label_dict, weight_dict=None): return {"Binary_Acc": acc, "IoU": iou} +# export model +def export(cfg: DictConfig): + # set model + model = TopOptNN(**cfg.MODEL) + + # initialize solver + solver = ppsci.solver.Solver( + model, + eval_with_no_grad=True, + pretrained_model_path=cfg.INFER.pretrained_model_path_dict[ + cfg.EXPORT.pretrained_model_name + ], + ) + + # export model + input_spec = [{"input": InputSpec([None, 2, 40, 40], "float32", name="input")}] + + solver.export(input_spec, cfg.INFER.export_path) + + +# model inference +def inference(cfg: DictConfig): + # read h5 data + h5data = h5py.File(cfg.DATA_PATH, "r") + data_iters = np.array(h5data["iters"]) + data_targets = np.array(h5data["targets"]) + idx = np.random.choice(len(data_iters), cfg.INFER.img_num, False) + data_iters = data_iters[idx] + data_targets = data_targets[idx] + + sampler = func_module.generate_sampler(cfg.INFER.sampler_key, cfg.INFER.sampler_num) + data_iters = channel_sampling(sampler, data_iters) + + predictor = pinn_predictor.PINNPredictor(cfg) + + input_dict = {"input": data_iters} + output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) + + # mapping data to output_key + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip({"output"}, output_dict.keys()) + } + + save_topopt_img( + input_dict, + output_dict, + data_iters, + cfg.INFER.save_res_path, + cfg.INFER.res_img_figsize, + cfg.INFER.save_npy, + ) + + +# used for inference +def channel_sampling(sampler, input): + SIMP_initial_iter_time = sampler() + input_channel_k = input[:, SIMP_initial_iter_time, :, :] + input_channel_k_minus_1 = input[:, SIMP_initial_iter_time - 1, :, :] + input = np.stack( + (input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1 + ) + return input + + +# used for inference +def save_topopt_img( + input_dict, output_dict, ground_truth, res_path, figsize=None, npy=False +): + + input = input_dict["input"] + output = output_dict["output"] + for i in range(len(input)): + plt.figure(figsize=figsize) + plt.subplot(1, 4, 1) + plt.axis("off") + plt.imshow(input[i][0], cmap="gray") + plt.title("Input Image") + plt.subplot(1, 4, 2) + plt.axis("off") + plt.imshow(input[i][1], cmap="gray") + plt.title("Input Gradient") + plt.subplot(1, 4, 3) + plt.axis("off") + plt.imshow(np.round(output[i][0]), cmap="gray") + print(output[i]) + plt.title("Prediction") + plt.subplot(1, 4, 4) + plt.axis("off") + plt.imshow(np.round(ground_truth[i][0]), cmap="gray") + print(ground_truth[i]) + plt.title("Ground Truth") + plt.show() + plt.savefig(osp.join(res_path, "Prediction_" + str(i) + ".png")) + plt.close() + if npy: + with open(osp(res_path, "Prediction_" + str(i) + ".npy"), "wb") as f: + np.save(f, output[i]) + + @hydra.main(version_base=None, config_path="./conf", config_name="topopt.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__": diff --git a/examples/topopt/topoptmodel.py b/examples/topopt/topoptmodel.py index ba35dc6a8..07dc82a53 100644 --- a/examples/topopt/topoptmodel.py +++ b/examples/topopt/topoptmodel.py @@ -32,7 +32,10 @@ class TopOptNN(ppsci.arch.UNetEx): kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3. filters (Tuple[int, ...], optional): Number of filters. Defaults to (16, 32, 64). layers (int, optional): Number of encoders or decoders. Defaults to 3. - channel_sampler (callable): The sampling function for the initial iteration time (corresponding to the channel number of the input) of SIMP algorithm. + channel_sampler (callable, optional): The sampling function for the initial iteration time + (corresponding to the channel number of the input) of SIMP algorithm. The default value + is None, when it is None, input for the forward method should be sampled and prepared + with the shape of [batch, 2, height, width] before passing to forward method. weight_norm (bool, optional): Whether use weight normalization layer. Defaults to True. batch_norm (bool, optional): Whether add batch normalization layer. Defaults to True. activation (Type[nn.Layer], optional): Name of activation function. Defaults to nn.ReLU. @@ -51,7 +54,7 @@ def __init__( kernel_size=3, filters=(16, 32, 64), layers=2, - channel_sampler=lambda: 1, + channel_sampler=None, weight_norm=False, batch_norm=False, activation=nn.ReLU, @@ -124,15 +127,17 @@ def __init__( ) def forward(self, x): - SIMP_initial_iter_time = self.channel_sampler() # channel k - input_channel_k = x[self.input_keys[0]][:, SIMP_initial_iter_time, :, :] - input_channel_k_minus_1 = x[self.input_keys[0]][ - :, SIMP_initial_iter_time - 1, :, : - ] - x = paddle.stack( - (input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1 - ) - + if self.channel_sampler is not None: + SIMP_initial_iter_time = self.channel_sampler() # channel k + input_channel_k = x[self.input_keys[0]][:, SIMP_initial_iter_time, :, :] + input_channel_k_minus_1 = x[self.input_keys[0]][ + :, SIMP_initial_iter_time - 1, :, : + ] + x = paddle.stack( + (input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1 + ) + else: + x = x[self.input_keys[0]] # encode upsampling_size = [] skip_connection = [] From 169c5183989190f1a4e2422ccdb3286c85de761a Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 4 Mar 2024 10:12:24 +0800 Subject: [PATCH 2/2] update --- docs/zh/examples/topopt.md | 38 +++++++++++++++++++++++++++++++- examples/topopt/conf/topopt.yaml | 6 ++--- examples/topopt/topopt.py | 33 +++++++++++++++++---------- 3 files changed, 61 insertions(+), 16 deletions(-) diff --git a/docs/zh/examples/topopt.md b/docs/zh/examples/topopt.md index 2722b95ce..5f61ae98d 100644 --- a/docs/zh/examples/topopt.md +++ b/docs/zh/examples/topopt.md @@ -28,6 +28,18 @@ python topopt.py mode=export INFER.pretrained_model_name=Uniform ``` + ``` sh + python topopt.py mode=export INFER.pretrained_model_name=Poisson5 + ``` + + ``` sh + python topopt.py mode=export INFER.pretrained_model_name=Poisson10 + ``` + + ``` sh + python topopt.py mode=export INFER.pretrained_model_name=Poisson30 + ``` + === "模型推理命令" ``` sh @@ -35,7 +47,31 @@ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/ # windows # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5 - python topopt.py mode=infer INFER.img_num=3 + python topopt.py mode=infer INFER.pretrained_model_name=Uniform INFER.img_num=3 + ``` + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/ + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5 + python topopt.py mode=infer INFER.pretrained_model_name=Poisson5 INFER.img_num=3 + ``` + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/ + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5 + python topopt.py mode=infer INFER.pretrained_model_name=Poisson10 INFER.img_num=3 + ``` + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/ + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5 + python topopt.py mode=infer INFER.pretrained_model_name=Poisson30 INFER.img_num=3 ``` | 预训练模型 | 指标 | diff --git a/examples/topopt/conf/topopt.yaml b/examples/topopt/conf/topopt.yaml index 6ff8b8ad8..05130f7fd 100644 --- a/examples/topopt/conf/topopt.yaml +++ b/examples/topopt/conf/topopt.yaml @@ -14,7 +14,7 @@ hydra: - EVAL.pretrained_model_path_dict - EVAL.batch_size - EVAL.num_val_step - - EXPORT.pretrained_model_name + - INFER.pretrained_model_name - INFER.pretrained_model_path_dict - INFER.export_path - INFER.batch_size @@ -68,7 +68,7 @@ EVAL: INFER: pretrained_model_name: null # a string, indicating which model you want to export. Support [Uniform, Poisson5, Poisson10, Poisson30]. pretrained_model_path_dict: {'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'} - export_path: ./inference/topopt + export_path: ./inference/topopt_${INFER.pretrained_model_name} pdmodel_path: ${INFER.export_path}.pdmodel pdpiparams_path: ${INFER.export_path}.pdiparams device: gpu @@ -86,5 +86,5 @@ INFER: sampler_num: 8 # a integer number, indicating the sampling rate of the sampling method, supported when `sampler_key` is Fixed or Poisson. img_num: 4 res_img_figsize: null - save_res_path: ./inference/predicted + save_res_path: ./inference/predicted_${INFER.pretrained_model_name} save_npy: false diff --git a/examples/topopt/topopt.py b/examples/topopt/topopt.py index a4b1d240b..3346f9f2d 100644 --- a/examples/topopt/topopt.py +++ b/examples/topopt/topopt.py @@ -13,20 +13,18 @@ # limitations under the License. from os import path as osp +from typing import Dict import functions as func_module import h5py import hydra -import matplotlib.pyplot as plt import numpy as np import paddle from omegaconf import DictConfig from paddle import nn -from paddle.static import InputSpec from topoptmodel import TopOptNN import ppsci -from deploy.python_infer import pinn_predictor from ppsci.utils import logger @@ -330,17 +328,18 @@ def export(cfg: DictConfig): model, eval_with_no_grad=True, pretrained_model_path=cfg.INFER.pretrained_model_path_dict[ - cfg.EXPORT.pretrained_model_name + cfg.INFER.pretrained_model_name ], ) # export model + from paddle.static import InputSpec + input_spec = [{"input": InputSpec([None, 2, 40, 40], "float32", name="input")}] solver.export(input_spec, cfg.INFER.export_path) -# model inference def inference(cfg: DictConfig): # read h5 data h5data = h5py.File(cfg.DATA_PATH, "r") @@ -353,6 +352,8 @@ def inference(cfg: DictConfig): sampler = func_module.generate_sampler(cfg.INFER.sampler_key, cfg.INFER.sampler_num) data_iters = channel_sampling(sampler, data_iters) + from deploy.python_infer import pinn_predictor + predictor = pinn_predictor.PINNPredictor(cfg) input_dict = {"input": data_iters} @@ -367,7 +368,7 @@ def inference(cfg: DictConfig): save_topopt_img( input_dict, output_dict, - data_iters, + data_targets, cfg.INFER.save_res_path, cfg.INFER.res_img_figsize, cfg.INFER.save_npy, @@ -387,11 +388,21 @@ def channel_sampling(sampler, input): # used for inference def save_topopt_img( - input_dict, output_dict, ground_truth, res_path, figsize=None, npy=False + input_dict: Dict[str, np.ndarray], + output_dict: Dict[str, np.ndarray], + ground_truth: np.ndarray, + save_dir: str, + figsize: tuple = None, + save_npy: bool = False, ): input = input_dict["input"] output = output_dict["output"] + import os + + import matplotlib.pyplot as plt + + os.makedirs(save_dir, exist_ok=True) for i in range(len(input)): plt.figure(figsize=figsize) plt.subplot(1, 4, 1) @@ -405,18 +416,16 @@ def save_topopt_img( plt.subplot(1, 4, 3) plt.axis("off") plt.imshow(np.round(output[i][0]), cmap="gray") - print(output[i]) plt.title("Prediction") plt.subplot(1, 4, 4) plt.axis("off") plt.imshow(np.round(ground_truth[i][0]), cmap="gray") - print(ground_truth[i]) plt.title("Ground Truth") plt.show() - plt.savefig(osp.join(res_path, "Prediction_" + str(i) + ".png")) + plt.savefig(osp.join(save_dir, f"Prediction_{i}.png")) plt.close() - if npy: - with open(osp(res_path, "Prediction_" + str(i) + ".npy"), "wb") as f: + if save_npy: + with open(osp(save_dir, f"Prediction_{i}.npy"), "wb") as f: np.save(f, output[i])