From 08685309a2fc46ce67fd65e36d50aa218ac0c44d Mon Sep 17 00:00:00 2001
From: AyaseNana <13659110308@163.com>
Date: Sat, 2 Mar 2024 21:43:08 +0800
Subject: [PATCH 1/2] add export and inference
---
docs/zh/examples/topopt.md | 16 +++++
examples/topopt/conf/topopt.yaml | 31 +++++++++
examples/topopt/topopt.py | 113 ++++++++++++++++++++++++++++++-
examples/topopt/topoptmodel.py | 27 +++++---
4 files changed, 174 insertions(+), 13 deletions(-)
diff --git a/docs/zh/examples/topopt.md b/docs/zh/examples/topopt.md
index 1cdd1c0a1..2722b95ce 100644
--- a/docs/zh/examples/topopt.md
+++ b/docs/zh/examples/topopt.md
@@ -22,6 +22,22 @@
python topopt.py mode=eval 'EVAL.pretrained_model_path_dict={'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'}'
```
+=== "模型导出命令"
+
+ ``` sh
+ python topopt.py mode=export INFER.pretrained_model_name=Uniform
+ ```
+
+=== "模型推理命令"
+
+ ``` sh
+ # linux
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
+ # windows
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
+ python topopt.py mode=infer INFER.img_num=3
+ ```
+
| 预训练模型 | 指标 |
|:--| :--|
| [topopt_uniform_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams) | loss(sup_validator): [0.14336, 0.10211, 0.07927, 0.06433, 0.04970, 0.04612, 0.04201, 0.03566, 0.03623, 0.03314, 0.02929, 0.02857, 0.02498, 0.02517, 0.02523, 0.02618]
metric.Binary_Acc(sup_validator): [0.9410, 0.9673, 0.9718, 0.9727, 0.9818, 0.9824, 0.9826, 0.9845, 0.9856, 0.9892, 0.9892, 0.9907, 0.9890, 0.9916, 0.9914, 0.9922]
metric.IoU(sup_validator): [0.8887, 0.9367, 0.9452, 0.9468, 0.9644, 0.9655, 0.9659, 0.9695, 0.9717, 0.9787, 0.9787, 0.9816, 0.9784, 0.9835, 0.9831, 0.9845] |
diff --git a/examples/topopt/conf/topopt.yaml b/examples/topopt/conf/topopt.yaml
index 6f6a7fbe7..6ff8b8ad8 100644
--- a/examples/topopt/conf/topopt.yaml
+++ b/examples/topopt/conf/topopt.yaml
@@ -14,8 +14,13 @@ hydra:
- EVAL.pretrained_model_path_dict
- EVAL.batch_size
- EVAL.num_val_step
+ - EXPORT.pretrained_model_name
+ - INFER.pretrained_model_path_dict
+ - INFER.export_path
+ - INFER.batch_size
- mode
- vol_coeff
+ - log_freq
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
@@ -25,6 +30,7 @@ hydra:
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
+log_freq: 20
# set default cases parameters
CASE_PARAM: [[Poisson, 5], [Poisson, 10], [Poisson, 30], [Uniform, null]]
@@ -57,3 +63,28 @@ EVAL:
pretrained_model_path_dict: null # a dict: {casename1:path1, casename2:path2, casename3:path3, casename4:path4}
num_val_step: 10 # the number of iteration for each evaluation case
batch_size: 16
+
+# inference settings
+INFER:
+ pretrained_model_name: null # a string, indicating which model you want to export. Support [Uniform, Poisson5, Poisson10, Poisson30].
+ pretrained_model_path_dict: {'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'}
+ export_path: ./inference/topopt
+ pdmodel_path: ${INFER.export_path}.pdmodel
+ pdpiparams_path: ${INFER.export_path}.pdiparams
+ device: gpu
+ engine: native
+ precision: fp32
+ onnx_path: null
+ ir_optim: true
+ min_subgraph_size: 30
+ gpu_mem: 4000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ batch_size: 4
+ sampler_key: Fixed # a string, indicating the sampling method. Support [Fixed, Uniform, Poisson].
+ sampler_num: 8 # a integer number, indicating the sampling rate of the sampling method, supported when `sampler_key` is Fixed or Poisson.
+ img_num: 4
+ res_img_figsize: null
+ save_res_path: ./inference/predicted
+ save_npy: false
diff --git a/examples/topopt/topopt.py b/examples/topopt/topopt.py
index 0a9768bfd..a4b1d240b 100644
--- a/examples/topopt/topopt.py
+++ b/examples/topopt/topopt.py
@@ -17,13 +17,16 @@
import functions as func_module
import h5py
import hydra
+import matplotlib.pyplot as plt
import numpy as np
import paddle
from omegaconf import DictConfig
from paddle import nn
+from paddle.static import InputSpec
from topoptmodel import TopOptNN
import ppsci
+from deploy.python_infer import pinn_predictor
from ppsci.utils import logger
@@ -120,7 +123,7 @@ def evaluate(cfg: DictConfig):
# fixed iteration stop times for evaluation
iterations_stop_times = range(5, 85, 5)
- model = TopOptNN()
+ model = TopOptNN(**cfg.MODEL)
# evaluation for 4 cases
acc_results_summary = {}
@@ -317,14 +320,120 @@ def val_metric(output_dict, label_dict, weight_dict=None):
return {"Binary_Acc": acc, "IoU": iou}
+# export model
+def export(cfg: DictConfig):
+ # set model
+ model = TopOptNN(**cfg.MODEL)
+
+ # initialize solver
+ solver = ppsci.solver.Solver(
+ model,
+ eval_with_no_grad=True,
+ pretrained_model_path=cfg.INFER.pretrained_model_path_dict[
+ cfg.EXPORT.pretrained_model_name
+ ],
+ )
+
+ # export model
+ input_spec = [{"input": InputSpec([None, 2, 40, 40], "float32", name="input")}]
+
+ solver.export(input_spec, cfg.INFER.export_path)
+
+
+# model inference
+def inference(cfg: DictConfig):
+ # read h5 data
+ h5data = h5py.File(cfg.DATA_PATH, "r")
+ data_iters = np.array(h5data["iters"])
+ data_targets = np.array(h5data["targets"])
+ idx = np.random.choice(len(data_iters), cfg.INFER.img_num, False)
+ data_iters = data_iters[idx]
+ data_targets = data_targets[idx]
+
+ sampler = func_module.generate_sampler(cfg.INFER.sampler_key, cfg.INFER.sampler_num)
+ data_iters = channel_sampling(sampler, data_iters)
+
+ predictor = pinn_predictor.PINNPredictor(cfg)
+
+ input_dict = {"input": data_iters}
+ output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
+
+ # mapping data to output_key
+ output_dict = {
+ store_key: output_dict[infer_key]
+ for store_key, infer_key in zip({"output"}, output_dict.keys())
+ }
+
+ save_topopt_img(
+ input_dict,
+ output_dict,
+ data_iters,
+ cfg.INFER.save_res_path,
+ cfg.INFER.res_img_figsize,
+ cfg.INFER.save_npy,
+ )
+
+
+# used for inference
+def channel_sampling(sampler, input):
+ SIMP_initial_iter_time = sampler()
+ input_channel_k = input[:, SIMP_initial_iter_time, :, :]
+ input_channel_k_minus_1 = input[:, SIMP_initial_iter_time - 1, :, :]
+ input = np.stack(
+ (input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1
+ )
+ return input
+
+
+# used for inference
+def save_topopt_img(
+ input_dict, output_dict, ground_truth, res_path, figsize=None, npy=False
+):
+
+ input = input_dict["input"]
+ output = output_dict["output"]
+ for i in range(len(input)):
+ plt.figure(figsize=figsize)
+ plt.subplot(1, 4, 1)
+ plt.axis("off")
+ plt.imshow(input[i][0], cmap="gray")
+ plt.title("Input Image")
+ plt.subplot(1, 4, 2)
+ plt.axis("off")
+ plt.imshow(input[i][1], cmap="gray")
+ plt.title("Input Gradient")
+ plt.subplot(1, 4, 3)
+ plt.axis("off")
+ plt.imshow(np.round(output[i][0]), cmap="gray")
+ print(output[i])
+ plt.title("Prediction")
+ plt.subplot(1, 4, 4)
+ plt.axis("off")
+ plt.imshow(np.round(ground_truth[i][0]), cmap="gray")
+ print(ground_truth[i])
+ plt.title("Ground Truth")
+ plt.show()
+ plt.savefig(osp.join(res_path, "Prediction_" + str(i) + ".png"))
+ plt.close()
+ if npy:
+ with open(osp(res_path, "Prediction_" + str(i) + ".npy"), "wb") as f:
+ np.save(f, output[i])
+
+
@hydra.main(version_base=None, config_path="./conf", config_name="topopt.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
+ elif cfg.mode == "export":
+ export(cfg)
+ elif cfg.mode == "infer":
+ inference(cfg)
else:
- raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
+ raise ValueError(
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
+ )
if __name__ == "__main__":
diff --git a/examples/topopt/topoptmodel.py b/examples/topopt/topoptmodel.py
index ba35dc6a8..07dc82a53 100644
--- a/examples/topopt/topoptmodel.py
+++ b/examples/topopt/topoptmodel.py
@@ -32,7 +32,10 @@ class TopOptNN(ppsci.arch.UNetEx):
kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3.
filters (Tuple[int, ...], optional): Number of filters. Defaults to (16, 32, 64).
layers (int, optional): Number of encoders or decoders. Defaults to 3.
- channel_sampler (callable): The sampling function for the initial iteration time (corresponding to the channel number of the input) of SIMP algorithm.
+ channel_sampler (callable, optional): The sampling function for the initial iteration time
+ (corresponding to the channel number of the input) of SIMP algorithm. The default value
+ is None, when it is None, input for the forward method should be sampled and prepared
+ with the shape of [batch, 2, height, width] before passing to forward method.
weight_norm (bool, optional): Whether use weight normalization layer. Defaults to True.
batch_norm (bool, optional): Whether add batch normalization layer. Defaults to True.
activation (Type[nn.Layer], optional): Name of activation function. Defaults to nn.ReLU.
@@ -51,7 +54,7 @@ def __init__(
kernel_size=3,
filters=(16, 32, 64),
layers=2,
- channel_sampler=lambda: 1,
+ channel_sampler=None,
weight_norm=False,
batch_norm=False,
activation=nn.ReLU,
@@ -124,15 +127,17 @@ def __init__(
)
def forward(self, x):
- SIMP_initial_iter_time = self.channel_sampler() # channel k
- input_channel_k = x[self.input_keys[0]][:, SIMP_initial_iter_time, :, :]
- input_channel_k_minus_1 = x[self.input_keys[0]][
- :, SIMP_initial_iter_time - 1, :, :
- ]
- x = paddle.stack(
- (input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1
- )
-
+ if self.channel_sampler is not None:
+ SIMP_initial_iter_time = self.channel_sampler() # channel k
+ input_channel_k = x[self.input_keys[0]][:, SIMP_initial_iter_time, :, :]
+ input_channel_k_minus_1 = x[self.input_keys[0]][
+ :, SIMP_initial_iter_time - 1, :, :
+ ]
+ x = paddle.stack(
+ (input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1
+ )
+ else:
+ x = x[self.input_keys[0]]
# encode
upsampling_size = []
skip_connection = []
From 169c5183989190f1a4e2422ccdb3286c85de761a Mon Sep 17 00:00:00 2001
From: AyaseNana <13659110308@163.com>
Date: Mon, 4 Mar 2024 10:12:24 +0800
Subject: [PATCH 2/2] update
---
docs/zh/examples/topopt.md | 38 +++++++++++++++++++++++++++++++-
examples/topopt/conf/topopt.yaml | 6 ++---
examples/topopt/topopt.py | 33 +++++++++++++++++----------
3 files changed, 61 insertions(+), 16 deletions(-)
diff --git a/docs/zh/examples/topopt.md b/docs/zh/examples/topopt.md
index 2722b95ce..5f61ae98d 100644
--- a/docs/zh/examples/topopt.md
+++ b/docs/zh/examples/topopt.md
@@ -28,6 +28,18 @@
python topopt.py mode=export INFER.pretrained_model_name=Uniform
```
+ ``` sh
+ python topopt.py mode=export INFER.pretrained_model_name=Poisson5
+ ```
+
+ ``` sh
+ python topopt.py mode=export INFER.pretrained_model_name=Poisson10
+ ```
+
+ ``` sh
+ python topopt.py mode=export INFER.pretrained_model_name=Poisson30
+ ```
+
=== "模型推理命令"
``` sh
@@ -35,7 +47,31 @@
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
- python topopt.py mode=infer INFER.img_num=3
+ python topopt.py mode=infer INFER.pretrained_model_name=Uniform INFER.img_num=3
+ ```
+
+ ``` sh
+ # linux
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
+ # windows
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
+ python topopt.py mode=infer INFER.pretrained_model_name=Poisson5 INFER.img_num=3
+ ```
+
+ ``` sh
+ # linux
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
+ # windows
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
+ python topopt.py mode=infer INFER.pretrained_model_name=Poisson10 INFER.img_num=3
+ ```
+
+ ``` sh
+ # linux
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
+ # windows
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
+ python topopt.py mode=infer INFER.pretrained_model_name=Poisson30 INFER.img_num=3
```
| 预训练模型 | 指标 |
diff --git a/examples/topopt/conf/topopt.yaml b/examples/topopt/conf/topopt.yaml
index 6ff8b8ad8..05130f7fd 100644
--- a/examples/topopt/conf/topopt.yaml
+++ b/examples/topopt/conf/topopt.yaml
@@ -14,7 +14,7 @@ hydra:
- EVAL.pretrained_model_path_dict
- EVAL.batch_size
- EVAL.num_val_step
- - EXPORT.pretrained_model_name
+ - INFER.pretrained_model_name
- INFER.pretrained_model_path_dict
- INFER.export_path
- INFER.batch_size
@@ -68,7 +68,7 @@ EVAL:
INFER:
pretrained_model_name: null # a string, indicating which model you want to export. Support [Uniform, Poisson5, Poisson10, Poisson30].
pretrained_model_path_dict: {'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'}
- export_path: ./inference/topopt
+ export_path: ./inference/topopt_${INFER.pretrained_model_name}
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
device: gpu
@@ -86,5 +86,5 @@ INFER:
sampler_num: 8 # a integer number, indicating the sampling rate of the sampling method, supported when `sampler_key` is Fixed or Poisson.
img_num: 4
res_img_figsize: null
- save_res_path: ./inference/predicted
+ save_res_path: ./inference/predicted_${INFER.pretrained_model_name}
save_npy: false
diff --git a/examples/topopt/topopt.py b/examples/topopt/topopt.py
index a4b1d240b..3346f9f2d 100644
--- a/examples/topopt/topopt.py
+++ b/examples/topopt/topopt.py
@@ -13,20 +13,18 @@
# limitations under the License.
from os import path as osp
+from typing import Dict
import functions as func_module
import h5py
import hydra
-import matplotlib.pyplot as plt
import numpy as np
import paddle
from omegaconf import DictConfig
from paddle import nn
-from paddle.static import InputSpec
from topoptmodel import TopOptNN
import ppsci
-from deploy.python_infer import pinn_predictor
from ppsci.utils import logger
@@ -330,17 +328,18 @@ def export(cfg: DictConfig):
model,
eval_with_no_grad=True,
pretrained_model_path=cfg.INFER.pretrained_model_path_dict[
- cfg.EXPORT.pretrained_model_name
+ cfg.INFER.pretrained_model_name
],
)
# export model
+ from paddle.static import InputSpec
+
input_spec = [{"input": InputSpec([None, 2, 40, 40], "float32", name="input")}]
solver.export(input_spec, cfg.INFER.export_path)
-# model inference
def inference(cfg: DictConfig):
# read h5 data
h5data = h5py.File(cfg.DATA_PATH, "r")
@@ -353,6 +352,8 @@ def inference(cfg: DictConfig):
sampler = func_module.generate_sampler(cfg.INFER.sampler_key, cfg.INFER.sampler_num)
data_iters = channel_sampling(sampler, data_iters)
+ from deploy.python_infer import pinn_predictor
+
predictor = pinn_predictor.PINNPredictor(cfg)
input_dict = {"input": data_iters}
@@ -367,7 +368,7 @@ def inference(cfg: DictConfig):
save_topopt_img(
input_dict,
output_dict,
- data_iters,
+ data_targets,
cfg.INFER.save_res_path,
cfg.INFER.res_img_figsize,
cfg.INFER.save_npy,
@@ -387,11 +388,21 @@ def channel_sampling(sampler, input):
# used for inference
def save_topopt_img(
- input_dict, output_dict, ground_truth, res_path, figsize=None, npy=False
+ input_dict: Dict[str, np.ndarray],
+ output_dict: Dict[str, np.ndarray],
+ ground_truth: np.ndarray,
+ save_dir: str,
+ figsize: tuple = None,
+ save_npy: bool = False,
):
input = input_dict["input"]
output = output_dict["output"]
+ import os
+
+ import matplotlib.pyplot as plt
+
+ os.makedirs(save_dir, exist_ok=True)
for i in range(len(input)):
plt.figure(figsize=figsize)
plt.subplot(1, 4, 1)
@@ -405,18 +416,16 @@ def save_topopt_img(
plt.subplot(1, 4, 3)
plt.axis("off")
plt.imshow(np.round(output[i][0]), cmap="gray")
- print(output[i])
plt.title("Prediction")
plt.subplot(1, 4, 4)
plt.axis("off")
plt.imshow(np.round(ground_truth[i][0]), cmap="gray")
- print(ground_truth[i])
plt.title("Ground Truth")
plt.show()
- plt.savefig(osp.join(res_path, "Prediction_" + str(i) + ".png"))
+ plt.savefig(osp.join(save_dir, f"Prediction_{i}.png"))
plt.close()
- if npy:
- with open(osp(res_path, "Prediction_" + str(i) + ".npy"), "wb") as f:
+ if save_npy:
+ with open(osp(save_dir, f"Prediction_{i}.npy"), "wb") as f:
np.save(f, output[i])