Skip to content

Commit

Permalink
【PPSCI Export&Infer No.31】heat_pinn (#926)
Browse files Browse the repository at this point in the history
* ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix

* ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix

* fix api docs in the timedomain

* fix api docs of timedomain

* fix api docs of timedomain

* ppsci api docs fixed

* ppsci api docs fixed

* ppsci api docs fixed

* add export and infer for bracket

* updata bracket doc

* solve conflict according to the branch named develop

* Update examples/bracket/conf/bracket.yaml

* Update examples/bracket/conf/bracket.yaml

* Update examples/bracket/conf/bracket.yaml

* add export&inference for bracket

* add export and infer for heat_pinn

* add export and infer for heat_pinn

* Update examples/heat_pinn/heat_pinn.py

* Update examples/heat_pinn/heat_pinn.py

* Update examples/heat_pinn/conf/heat_pinn.yaml

---------

Co-authored-by: krp <2934631798@qq.com>
Co-authored-by: HydrogenSulfate <490868991@qq.com>
  • Loading branch information
3 people authored Jun 14, 2024
1 parent 5edd8f2 commit e54915d
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 127 deletions.
12 changes: 12 additions & 0 deletions docs/zh/examples/heat_pinn.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
python heat_pinn.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python heat_pinn.py mode=export
```

=== "模型推理命令"

``` sh
python heat_pinn.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [heat_pinn_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams) | norm MSE loss between the FDM and PINN is 1.30174e-03 |
Expand Down
20 changes: 19 additions & 1 deletion examples/heat_pinn/conf/heat_pinn.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
hydra:
run:
# dynamic output directory according to running time and override name
dir: outputs_bracket/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
dir: outputs_heat_pinn/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
job:
name: ${mode} # name of logfile
chdir: false # keep current working directory unchanged
Expand Down Expand Up @@ -50,3 +50,21 @@ TRAIN:
# evaluation settings
EVAL:
pretrained_model_path: null

# inference settings
INFER:
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams"
export_path: ./inference/heat_pinn
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: true
min_subgraph_size: 10
gpu_mem: 2000
gpu_id: 0
max_batch_size: 128
num_cpu_threads: 4
batch_size: 128
238 changes: 112 additions & 126 deletions examples/heat_pinn/heat_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,74 @@
from ppsci.utils import logger


def plot(input_data, N_EVAL, pinn_output, fdm_output, cfg):
x = input_data["x"].reshape(N_EVAL, N_EVAL)
y = input_data["y"].reshape(N_EVAL, N_EVAL)

plt.subplot(2, 1, 1)
plt.pcolormesh(x, y, pinn_output * 75.0, cmap="magma")
plt.colorbar()
plt.title("PINN")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.axis("square")

plt.subplot(2, 1, 2)
plt.pcolormesh(x, y, fdm_output, cmap="magma")
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.title("FDM")
plt.tight_layout()
plt.axis("square")
plt.savefig(osp.join(cfg.output_dir, "pinn_fdm_comparison.png"))
plt.close()

frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
frames = [*map(int, (frames_val + 1) / 2 * (N_EVAL - 1))]
height = 3
plt.figure("", figsize=(len(frames) * height, 2 * height))

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), i + 1)
plt.title(f"y = {frames_val[i]:.2f}")
plt.plot(
x[:, var_index],
pinn_output[:, var_index] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(x[:, var_index], fdm_output[:, var_index], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("x")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), len(frames) + i + 1)
plt.title(f"x = {frames_val[i]:.2f}")
plt.plot(
y[var_index, :],
pinn_output[var_index, :] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(y[var_index, :], fdm_output[var_index, :], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("y")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

plt.savefig(osp.join(cfg.output_dir, "profiles.png"))


def train(cfg: DictConfig):
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)
Expand Down Expand Up @@ -141,72 +209,7 @@ def train(cfg: DictConfig):
fdm_output = fdm.solve(N_EVAL, 1).T
mse_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss}")

x = input_data["x"].reshape(N_EVAL, N_EVAL)
y = input_data["y"].reshape(N_EVAL, N_EVAL)

plt.subplot(2, 1, 1)
plt.pcolormesh(x, y, pinn_output * 75.0, cmap="magma")
plt.colorbar()
plt.title("PINN")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.axis("square")

plt.subplot(2, 1, 2)
plt.pcolormesh(x, y, fdm_output, cmap="magma")
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.title("FDM")
plt.tight_layout()
plt.axis("square")
plt.savefig(osp.join(cfg.output_dir, "pinn_fdm_comparison.png"))
plt.close()

frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
frames = [*map(int, (frames_val + 1) / 2 * (N_EVAL - 1))]
height = 3
plt.figure("", figsize=(len(frames) * height, 2 * height))

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), i + 1)
plt.title(f"y = {frames_val[i]:.2f}")
plt.plot(
x[:, var_index],
pinn_output[:, var_index] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(x[:, var_index], fdm_output[:, var_index], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("x")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), len(frames) + i + 1)
plt.title(f"x = {frames_val[i]:.2f}")
plt.plot(
y[var_index, :],
pinn_output[var_index, :] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(y[var_index, :], fdm_output[var_index, :], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("y")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

plt.savefig(osp.join(cfg.output_dir, "profiles.png"))
plot(input_data, N_EVAL, pinn_output, fdm_output, cfg)


def evaluate(cfg: DictConfig):
Expand Down Expand Up @@ -239,72 +242,49 @@ def evaluate(cfg: DictConfig):
fdm_output = fdm.solve(N_EVAL, 1).T
mse_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss:.5e}")
plot(input_data, N_EVAL, pinn_output, fdm_output, cfg)

x = input_data["x"].reshape(N_EVAL, N_EVAL)
y = input_data["y"].reshape(N_EVAL, N_EVAL)

plt.subplot(2, 1, 1)
plt.pcolormesh(x, y, pinn_output * 75.0, cmap="magma")
plt.colorbar()
plt.title("PINN")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.axis("square")
def export(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

plt.subplot(2, 1, 2)
plt.pcolormesh(x, y, fdm_output, cmap="magma")
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.title("FDM")
plt.tight_layout()
plt.axis("square")
plt.savefig(osp.join(cfg.output_dir, "pinn_fdm_comparison.png"))
plt.close()
# initialize solver
solver = ppsci.solver.Solver(
model,
cfg=cfg,
)
# export model
from paddle.static import InputSpec

frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
frames = [*map(int, (frames_val + 1) / 2 * (N_EVAL - 1))]
height = 3
plt.figure("", figsize=(len(frames) * height, 2 * height))
input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
solver.export(input_spec, cfg.INFER.export_path)

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), i + 1)
plt.title(f"y = {frames_val[i]:.2f}")
plt.plot(
x[:, var_index],
pinn_output[:, var_index] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(x[:, var_index], fdm_output[:, var_index], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("x")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), len(frames) + i + 1)
plt.title(f"x = {frames_val[i]:.2f}")
plt.plot(
y[var_index, :],
pinn_output[var_index, :] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(y[var_index, :], fdm_output[var_index, :], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("y")
plt.ylabel("T")
plt.tight_layout()
plt.legend()
def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

plt.savefig(osp.join(cfg.output_dir, "profiles.png"))
predictor = pinn_predictor.PINNPredictor(cfg)
# set geometry
geom = {"rect": ppsci.geometry.Rectangle((-1.0, -1.0), (1.0, 1.0))}
# begin eval
N_EVAL = 100
input_data = geom["rect"].sample_interior(N_EVAL**2, evenly=True)
output_data = predictor.predict(
{key: input_data[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
)

# mapping data to cfg.INFER.output_keys
output_data = {
store_key: output_data[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_data.keys())
}["u"].reshape(N_EVAL, N_EVAL)
fdm_output = fdm.solve(N_EVAL, 1).T
mse_loss = np.mean(np.square(output_data - (fdm_output / 75.0)))
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss:.5e}")
plot(input_data, N_EVAL, output_data, fdm_output, cfg)


@hydra.main(version_base=None, config_path="./conf", config_name="heat_pinn.yaml")
Expand All @@ -313,8 +293,14 @@ def main(cfg: DictConfig):
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down

0 comments on commit e54915d

Please sign in to comment.