Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PPSCI Export&Infer No. 29】 add export and inference #793

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions docs/zh/examples/topopt.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,58 @@
python topopt.py mode=eval 'EVAL.pretrained_model_path_dict={'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'}'
```

=== "模型导出命令"

``` sh
python topopt.py mode=export INFER.pretrained_model_name=Uniform
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate Mar 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否可以以注释的形式补充另外几个Poisson系列的模型呢?模型推理命令同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

```

``` sh
python topopt.py mode=export INFER.pretrained_model_name=Poisson5
```

``` sh
python topopt.py mode=export INFER.pretrained_model_name=Poisson10
```

``` sh
python topopt.py mode=export INFER.pretrained_model_name=Poisson30
```

=== "模型推理命令"

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
python topopt.py mode=infer INFER.pretrained_model_name=Uniform INFER.img_num=3
```

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
python topopt.py mode=infer INFER.pretrained_model_name=Poisson5 INFER.img_num=3
```

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
python topopt.py mode=infer INFER.pretrained_model_name=Poisson10 INFER.img_num=3
```

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 -P ./datasets/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5 --output ./datasets/top_dataset.h5
python topopt.py mode=infer INFER.pretrained_model_name=Poisson30 INFER.img_num=3
```

| 预训练模型 | 指标 |
|:--| :--|
| [topopt_uniform_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams) | loss(sup_validator): [0.14336, 0.10211, 0.07927, 0.06433, 0.04970, 0.04612, 0.04201, 0.03566, 0.03623, 0.03314, 0.02929, 0.02857, 0.02498, 0.02517, 0.02523, 0.02618]<br>metric.Binary_Acc(sup_validator): [0.9410, 0.9673, 0.9718, 0.9727, 0.9818, 0.9824, 0.9826, 0.9845, 0.9856, 0.9892, 0.9892, 0.9907, 0.9890, 0.9916, 0.9914, 0.9922]<br>metric.IoU(sup_validator): [0.8887, 0.9367, 0.9452, 0.9468, 0.9644, 0.9655, 0.9659, 0.9695, 0.9717, 0.9787, 0.9787, 0.9816, 0.9784, 0.9835, 0.9831, 0.9845] |
Expand Down
31 changes: 31 additions & 0 deletions examples/topopt/conf/topopt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ hydra:
- EVAL.pretrained_model_path_dict
- EVAL.batch_size
- EVAL.num_val_step
- INFER.pretrained_model_name
- INFER.pretrained_model_path_dict
- INFER.export_path
- INFER.batch_size
- mode
- vol_coeff
- log_freq
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
Expand All @@ -25,6 +30,7 @@ hydra:
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 20

# set default cases parameters
CASE_PARAM: [[Poisson, 5], [Poisson, 10], [Poisson, 30], [Uniform, null]]
Expand Down Expand Up @@ -57,3 +63,28 @@ EVAL:
pretrained_model_path_dict: null # a dict: {casename1:path1, casename2:path2, casename3:path3, casename4:path4}
num_val_step: 10 # the number of iteration for each evaluation case
batch_size: 16

# inference settings
INFER:
pretrained_model_name: null # a string, indicating which model you want to export. Support [Uniform, Poisson5, Poisson10, Poisson30].
pretrained_model_path_dict: {'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'}
export_path: ./inference/topopt_${INFER.pretrained_model_name}
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
device: gpu
engine: native
precision: fp32
onnx_path: null
ir_optim: true
min_subgraph_size: 30
gpu_mem: 4000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
batch_size: 4
sampler_key: Fixed # a string, indicating the sampling method. Support [Fixed, Uniform, Poisson].
sampler_num: 8 # a integer number, indicating the sampling rate of the sampling method, supported when `sampler_key` is Fixed or Poisson.
img_num: 4
res_img_figsize: null
save_res_path: ./inference/predicted_${INFER.pretrained_model_name}
save_npy: false
122 changes: 120 additions & 2 deletions examples/topopt/topopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from os import path as osp
from typing import Dict

import functions as func_module
import h5py
Expand Down Expand Up @@ -120,7 +121,7 @@ def evaluate(cfg: DictConfig):

# fixed iteration stop times for evaluation
iterations_stop_times = range(5, 85, 5)
model = TopOptNN()
model = TopOptNN(**cfg.MODEL)

# evaluation for 4 cases
acc_results_summary = {}
Expand Down Expand Up @@ -317,14 +318,131 @@ def val_metric(output_dict, label_dict, weight_dict=None):
return {"Binary_Acc": acc, "IoU": iou}


# export model
def export(cfg: DictConfig):
# set model
model = TopOptNN(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
eval_with_no_grad=True,
pretrained_model_path=cfg.INFER.pretrained_model_path_dict[
cfg.INFER.pretrained_model_name
],
)

# export model
from paddle.static import InputSpec

input_spec = [{"input": InputSpec([None, 2, 40, 40], "float32", name="input")}]

solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
# read h5 data
h5data = h5py.File(cfg.DATA_PATH, "r")
data_iters = np.array(h5data["iters"])
data_targets = np.array(h5data["targets"])
idx = np.random.choice(len(data_iters), cfg.INFER.img_num, False)
data_iters = data_iters[idx]
data_targets = data_targets[idx]

sampler = func_module.generate_sampler(cfg.INFER.sampler_key, cfg.INFER.sampler_num)
data_iters = channel_sampling(sampler, data_iters)

from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)

input_dict = {"input": data_iters}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)

# mapping data to output_key
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip({"output"}, output_dict.keys())
}

save_topopt_img(
input_dict,
output_dict,
data_targets,
cfg.INFER.save_res_path,
cfg.INFER.res_img_figsize,
cfg.INFER.save_npy,
)


# used for inference
def channel_sampling(sampler, input):
SIMP_initial_iter_time = sampler()
input_channel_k = input[:, SIMP_initial_iter_time, :, :]
input_channel_k_minus_1 = input[:, SIMP_initial_iter_time - 1, :, :]
input = np.stack(
(input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1
)
return input


# used for inference
def save_topopt_img(
input_dict: Dict[str, np.ndarray],
output_dict: Dict[str, np.ndarray],
ground_truth: np.ndarray,
save_dir: str,
figsize: tuple = None,
save_npy: bool = False,
):

input = input_dict["input"]
output = output_dict["output"]
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate Mar 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

保存前先创建文件夹,否则会报错

Suggested change
output = output_dict["output"]
output = output_dict["output"]
import os
os.makedirs(res_path, exist_ok=True)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

import os

import matplotlib.pyplot as plt

os.makedirs(save_dir, exist_ok=True)
for i in range(len(input)):
plt.figure(figsize=figsize)
plt.subplot(1, 4, 1)
plt.axis("off")
plt.imshow(input[i][0], cmap="gray")
plt.title("Input Image")
plt.subplot(1, 4, 2)
plt.axis("off")
plt.imshow(input[i][1], cmap="gray")
plt.title("Input Gradient")
plt.subplot(1, 4, 3)
plt.axis("off")
plt.imshow(np.round(output[i][0]), cmap="gray")
plt.title("Prediction")
plt.subplot(1, 4, 4)
plt.axis("off")
plt.imshow(np.round(ground_truth[i][0]), cmap="gray")
plt.title("Ground Truth")
plt.show()
plt.savefig(osp.join(save_dir, f"Prediction_{i}.png"))
plt.close()
if save_npy:
with open(osp(save_dir, f"Prediction_{i}.npy"), "wb") as f:
np.save(f, output[i])


@hydra.main(version_base=None, config_path="./conf", config_name="topopt.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
27 changes: 16 additions & 11 deletions examples/topopt/topoptmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ class TopOptNN(ppsci.arch.UNetEx):
kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3.
filters (Tuple[int, ...], optional): Number of filters. Defaults to (16, 32, 64).
layers (int, optional): Number of encoders or decoders. Defaults to 3.
channel_sampler (callable): The sampling function for the initial iteration time (corresponding to the channel number of the input) of SIMP algorithm.
channel_sampler (callable, optional): The sampling function for the initial iteration time
(corresponding to the channel number of the input) of SIMP algorithm. The default value
is None, when it is None, input for the forward method should be sampled and prepared
with the shape of [batch, 2, height, width] before passing to forward method.
weight_norm (bool, optional): Whether use weight normalization layer. Defaults to True.
batch_norm (bool, optional): Whether add batch normalization layer. Defaults to True.
activation (Type[nn.Layer], optional): Name of activation function. Defaults to nn.ReLU.
Expand All @@ -51,7 +54,7 @@ def __init__(
kernel_size=3,
filters=(16, 32, 64),
layers=2,
channel_sampler=lambda: 1,
channel_sampler=None,
weight_norm=False,
batch_norm=False,
activation=nn.ReLU,
Expand Down Expand Up @@ -124,15 +127,17 @@ def __init__(
)

def forward(self, x):
SIMP_initial_iter_time = self.channel_sampler() # channel k
input_channel_k = x[self.input_keys[0]][:, SIMP_initial_iter_time, :, :]
input_channel_k_minus_1 = x[self.input_keys[0]][
:, SIMP_initial_iter_time - 1, :, :
]
x = paddle.stack(
(input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1
)

if self.channel_sampler is not None:
SIMP_initial_iter_time = self.channel_sampler() # channel k
input_channel_k = x[self.input_keys[0]][:, SIMP_initial_iter_time, :, :]
input_channel_k_minus_1 = x[self.input_keys[0]][
:, SIMP_initial_iter_time - 1, :, :
]
x = paddle.stack(
(input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1
)
else:
x = x[self.input_keys[0]]
# encode
upsampling_size = []
skip_connection = []
Expand Down