Skip to content

Commit 965156d

Browse files
【SCU】【PPSCI Export&Infer No.32】phygeonet (#1036)
* add_phygeonet * fix codestyle * fix codestyle * fix * fix codestyle * Update docs/zh/examples/phygeonet.md * Update docs/zh/examples/phygeonet.md * Update docs/zh/examples/phygeonet.md * Update docs/zh/examples/phygeonet.md * Update examples/phygeonet/conf/heat_equation.yaml --------- Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent 932fe0d commit 965156d

File tree

5 files changed

+281
-2
lines changed

5 files changed

+281
-2
lines changed

docs/zh/examples/phygeonet.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,40 @@
5353

5454
```
5555

56+
=== "模型导出命令"
57+
58+
``` sh
59+
# heat_equation
60+
python heat_equation.py mode=export
61+
62+
# heat_equation_bc
63+
python heat_equation_with_bc.py mode=export
64+
```
65+
66+
=== "模型推理命令"
67+
68+
``` sh
69+
# heat_equation
70+
# linux
71+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation.npz -P ./data/
72+
73+
# windows
74+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation.npz --create-dirs -o ./data/heat_equation.npz
75+
76+
python heat_equation.py mode=infer
77+
78+
# heat_equation_bc
79+
# linux
80+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc.npz -P ./data/
81+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc_test.npz -P ./data/
82+
83+
# windows
84+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc.npz --create-dirs -o ./data/heat_equation.npz
85+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc_test.npz --create-dirs -o ./data/heat_equation.npz
86+
87+
python heat_equation_with_bc.py mode=infer
88+
```
89+
5690
| 模型 | mRes | ev |
5791
| :-- | :-- | :-- |
5892
| [heat_equation_pretrain.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_pretrain.pdparams) | 0.815 |0.095|

examples/phygeonet/conf/heat_equation.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,21 @@ TRAIN:
5151
EVAL:
5252
pretrained_model_path: null
5353
eval_with_no_grad: true
54+
55+
# inference settings
56+
INFER:
57+
pretrained_model_path: 'https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_pretrain.pdparams'
58+
export_path: ./inference/heat_equation
59+
pdmodel_path: ${INFER.export_path}.pdmodel
60+
pdiparams_path: ${INFER.export_path}.pdiparams
61+
onnx_path: ${INFER.export_path}.onnx
62+
device: gpu
63+
engine: native
64+
precision: fp32
65+
ir_optim: true
66+
min_subgraph_size: 5
67+
gpu_mem: 20
68+
gpu_id: 0
69+
max_batch_size: 256
70+
num_cpu_threads: 10
71+
batch_size: 256

examples/phygeonet/conf/heat_equation_with_bc.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ hydra:
2424

2525
# general settings
2626
mode: train # running mode: train/eval
27+
log_freq: 50
2728
seed: 66
2829
data_dir: ./data/heat_equation_bc.npz
2930
test_data_dir: ./data/heat_equation_bc_test.npz
@@ -53,3 +54,21 @@ TRAIN:
5354
EVAL:
5455
pretrained_model_path: null
5556
eval_with_no_grad: true
57+
58+
# inference settings
59+
INFER:
60+
pretrained_model_path: 'https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_bc_pretrain.pdparams'
61+
export_path: ./inference/heat_equation_bc
62+
pdmodel_path: ${INFER.export_path}.pdmodel
63+
pdiparams_path: ${INFER.export_path}.pdiparams
64+
onnx_path: ${INFER.export_path}.onnx
65+
device: gpu
66+
engine: native
67+
precision: fp32
68+
ir_optim: true
69+
min_subgraph_size: 5
70+
gpu_mem: 20
71+
gpu_id: 0
72+
max_batch_size: 256
73+
num_cpu_threads: 10
74+
batch_size: 256

examples/phygeonet/heat_equation.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os.path as osp
12
from typing import Dict
23

34
import hydra
@@ -153,14 +154,102 @@ def evaluate(cfg: DictConfig):
153154
plt.close(fig)
154155

155156

157+
def export(cfg: DictConfig):
158+
model = ppsci.arch.USCNN(**cfg.MODEL)
159+
# initialize solver
160+
solver = ppsci.solver.Solver(
161+
model,
162+
pretrained_model_path=cfg.INFER.pretrained_model_path,
163+
)
164+
# export model
165+
from paddle.static import InputSpec
166+
167+
input_spec = [
168+
{
169+
key: InputSpec([None, 2, 19, 84], "float32", name=key)
170+
for key in model.input_keys
171+
},
172+
]
173+
solver.export(input_spec, cfg.INFER.export_path)
174+
175+
176+
def inference(cfg: DictConfig):
177+
from deploy.python_infer import pinn_predictor
178+
179+
predictor = pinn_predictor.PINNPredictor(cfg)
180+
data = np.load(cfg.data_dir)
181+
coords = data["coords"]
182+
ofv_sb = data["OFV_sb"]
183+
184+
## create model
185+
pad_singleside = cfg.MODEL.pad_singleside
186+
input_spec = {"coords": coords}
187+
188+
output_v = predictor.predict(input_spec, cfg.INFER.batch_size)
189+
# mapping data to cfg.INFER.output_keys
190+
output_v = {
191+
store_key: output_v[infer_key]
192+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_v.keys())
193+
}
194+
195+
output_v = output_v["output_v"]
196+
197+
output_v[0, 0, -pad_singleside:, pad_singleside:-pad_singleside] = 0
198+
output_v[0, 0, :pad_singleside, pad_singleside:-pad_singleside] = 1
199+
output_v[0, 0, pad_singleside:-pad_singleside, -pad_singleside:] = 1
200+
output_v[0, 0, pad_singleside:-pad_singleside, 0:pad_singleside] = 1
201+
output_v[0, 0, 0, 0] = 0.5 * (output_v[0, 0, 0, 1] + output_v[0, 0, 1, 0])
202+
output_v[0, 0, 0, -1] = 0.5 * (output_v[0, 0, 0, -2] + output_v[0, 0, 1, -1])
203+
204+
ev = paddle.sqrt(
205+
paddle.mean((ofv_sb - output_v[0, 0]) ** 2) / paddle.mean(ofv_sb**2)
206+
).item()
207+
logger.info(f"ev: {ev}")
208+
209+
fig = plt.figure()
210+
ax = plt.subplot(1, 2, 1)
211+
utils.visualize(
212+
ax,
213+
coords[0, 0, 1:-1, 1:-1],
214+
coords[0, 1, 1:-1, 1:-1],
215+
output_v[0, 0, 1:-1, 1:-1],
216+
"horizontal",
217+
[0, 1],
218+
)
219+
utils.set_axis_label(ax, "p")
220+
ax.set_title("CNN " + r"$T$")
221+
ax.set_aspect("equal")
222+
ax = plt.subplot(1, 2, 2)
223+
utils.visualize(
224+
ax,
225+
coords[0, 0, 1:-1, 1:-1],
226+
coords[0, 1, 1:-1, 1:-1],
227+
ofv_sb[1:-1, 1:-1],
228+
"horizontal",
229+
[0, 1],
230+
)
231+
utils.set_axis_label(ax, "p")
232+
ax.set_aspect("equal")
233+
ax.set_title("FV " + r"$T$")
234+
fig.tight_layout(pad=1)
235+
fig.savefig(osp.join(cfg.output_dir, "result.png"), bbox_inches="tight")
236+
plt.close(fig)
237+
238+
156239
@hydra.main(version_base=None, config_path="./conf", config_name="heat_equation.yaml")
157240
def main(cfg: DictConfig):
158241
if cfg.mode == "train":
159242
train(cfg)
160243
elif cfg.mode == "eval":
161244
evaluate(cfg)
245+
elif cfg.mode == "export":
246+
export(cfg)
247+
elif cfg.mode == "infer":
248+
inference(cfg)
162249
else:
163-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
250+
raise ValueError(
251+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
252+
)
164253

165254

166255
if __name__ == "__main__":

examples/phygeonet/heat_equation_with_bc.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,119 @@ def evaluate(cfg: DictConfig):
188188
plt.close(fig1)
189189

190190

191+
def export(cfg: DictConfig):
192+
model = ppsci.arch.USCNN(**cfg.MODEL)
193+
# initialize solver
194+
solver = ppsci.solver.Solver(
195+
model,
196+
pretrained_model_path=cfg.INFER.pretrained_model_path,
197+
)
198+
# export model
199+
from paddle.static import InputSpec
200+
201+
input_spec = [
202+
{
203+
key: InputSpec([None, 1, 19, 84], "float32", name=key)
204+
for key in model.input_keys
205+
},
206+
]
207+
solver.export(input_spec, cfg.INFER.export_path)
208+
209+
210+
def inference(cfg: DictConfig):
211+
from deploy.python_infer import pinn_predictor
212+
213+
predictor = pinn_predictor.PINNPredictor(cfg)
214+
pad_singleside = cfg.MODEL.pad_singleside
215+
216+
data = np.load(cfg.test_data_dir)
217+
paras = data["paras"]
218+
truths = data["truths"]
219+
coords = data["coords"]
220+
221+
paras = paras.reshape([paras.shape[0], 1, paras.shape[1], paras.shape[2]])
222+
input_spec = {"coords": paras}
223+
output_v = predictor.predict(input_spec, cfg.INFER.batch_size)
224+
# mapping data to cfg.INFER.output_keys
225+
output_v = {
226+
store_key: output_v[infer_key]
227+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_v.keys())
228+
}
229+
output_v = output_v["output_v"]
230+
num_sample = output_v.shape[0]
231+
for j in range(num_sample):
232+
# Impose BC
233+
output_v[j, 0, -pad_singleside:, pad_singleside:-pad_singleside] = output_v[
234+
j, 0, 1:2, pad_singleside:-pad_singleside
235+
]
236+
output_v[j, 0, :pad_singleside, pad_singleside:-pad_singleside] = output_v[
237+
j, 0, -2:-1, pad_singleside:-pad_singleside
238+
]
239+
output_v[j, 0, :, -pad_singleside:] = 0
240+
output_v[j, 0, :, 0:pad_singleside] = paras[j, 0, 0, 0]
241+
242+
error = paddle.sqrt(
243+
paddle.mean((truths - output_v) ** 2) / paddle.mean(truths**2)
244+
).item()
245+
logger.info(f"The average error: {error / num_sample}")
246+
247+
output_vs = output_v
248+
PARALIST = [1, 2, 3, 4, 5, 6, 7]
249+
for i in range(len(PARALIST)):
250+
truth = truths[i]
251+
coord = coords[i]
252+
output_v = output_vs[i]
253+
truth = truth.reshape(1, 1, truth.shape[0], truth.shape[1])
254+
coord = coord.reshape(1, 2, coord.shape[2], coord.shape[3])
255+
fig1 = plt.figure()
256+
xylabelsize = 20
257+
xytickssize = 20
258+
titlesize = 20
259+
ax = plt.subplot(1, 2, 1)
260+
_, cbar = utils.visualize(
261+
ax,
262+
coord[0, 0, :, :],
263+
coord[0, 1, :, :],
264+
output_v[0, :, :],
265+
"horizontal",
266+
[0, max(PARALIST)],
267+
)
268+
ax.set_aspect("equal")
269+
utils.set_axis_label(ax, "p")
270+
ax.set_title("PhyGeoNet " + r"$T$", fontsize=titlesize)
271+
ax.set_xlabel(xlabel=r"$x$", fontsize=xylabelsize)
272+
ax.set_ylabel(ylabel=r"$y$", fontsize=xylabelsize)
273+
ax.set_xticks([-1, 0, 1])
274+
ax.set_yticks([-1, 0, 1])
275+
ax.tick_params(axis="x", labelsize=xytickssize)
276+
ax.tick_params(axis="y", labelsize=xytickssize)
277+
cbar.set_ticks([0, 1, 2, 3, 4, 5, 6, 7])
278+
cbar.ax.tick_params(labelsize=xytickssize)
279+
ax = plt.subplot(1, 2, 2)
280+
_, cbar = utils.visualize(
281+
ax,
282+
coord[0, 0, :, :],
283+
coord[0, 1, :, :],
284+
truth[0, 0, :, :],
285+
"horizontal",
286+
[0, max(PARALIST)],
287+
)
288+
ax.set_aspect("equal")
289+
utils.set_axis_label(ax, "p")
290+
ax.set_title("FV " + r"$T$", fontsize=titlesize)
291+
ax.set_xlabel(xlabel=r"$x$", fontsize=xylabelsize)
292+
ax.set_ylabel(ylabel=r"$y$", fontsize=xylabelsize)
293+
ax.set_xticks([-1, 0, 1])
294+
ax.set_yticks([-1, 0, 1])
295+
ax.tick_params(axis="x", labelsize=xytickssize)
296+
ax.tick_params(axis="y", labelsize=xytickssize)
297+
cbar.set_ticks([0, 1, 2, 3, 4, 5, 6, 7])
298+
cbar.ax.tick_params(labelsize=xytickssize)
299+
fig1.tight_layout(pad=1)
300+
fig1.savefig(osp.join(cfg.output_dir, f"Para{i}T.png"), bbox_inches="tight")
301+
plt.close(fig1)
302+
303+
191304
@hydra.main(
192305
version_base=None, config_path="./conf", config_name="heat_equation_with_bc.yaml"
193306
)
@@ -196,8 +309,14 @@ def main(cfg: DictConfig):
196309
train(cfg)
197310
elif cfg.mode == "eval":
198311
evaluate(cfg)
312+
elif cfg.mode == "export":
313+
export(cfg)
314+
elif cfg.mode == "infer":
315+
inference(cfg)
199316
else:
200-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
317+
raise ValueError(
318+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
319+
)
201320

202321

203322
if __name__ == "__main__":

0 commit comments

Comments
 (0)