diff --git a/docs/zh/examples/hpinns.md b/docs/zh/examples/hpinns.md index dc98d6da10..4ecc23a46d 100644 --- a/docs/zh/examples/hpinns.md +++ b/docs/zh/examples/hpinns.md @@ -45,7 +45,6 @@ ``` - | 预训练模型 | 指标 | |:--| :--| | [hpinns_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/hPINNs/hpinns_pretrained.pdparams) | loss(opt_sup): 0.05352
MSE.eval_metric(opt_sup): 0.00002
loss(val_sup): 0.02205
MSE.eval_metric(val_sup): 0.00001 | diff --git a/examples/hpinns/holography.py b/examples/hpinns/holography.py index fd7ea5b784..290f2b3c2a 100644 --- a/examples/hpinns/holography.py +++ b/examples/hpinns/holography.py @@ -460,12 +460,19 @@ def inference(cfg: DictConfig): for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys()) } - ppsci.visualize.save_vtu_from_dict( - "./hpinns_pred.vtu", - {**input_dict, **output_dict}, - input_dict.keys(), - cfg.INFER.output_keys, + # plotting E and eps + N = ((func_module.l_BOX[1] - func_module.l_BOX[0]) / 0.05).astype(int) + input_eval = np.stack((input_dict["x"], input_dict["y"]), axis=-1).reshape( + N[0], N[1], 2 ) + e_re = output_dict["e_re"].reshape(N[0], N[1]) + e_im = output_dict["e_im"].reshape(N[0], N[1]) + eps = output_dict["eps"].reshape(N[0], N[1]) + v_visual = e_re**2 + e_im**2 + field_visual = np.stack((v_visual, eps), axis=-1) + plot_module.field_name = ["Fig7_E", "Fig7_eps"] + plot_module.FIGNAME = "hpinns_pred" + plot_module.plot_field_holo(input_eval, field_visual) @hydra.main(version_base=None, config_path="./conf", config_name="hpinns.yaml") diff --git a/examples/hpinns/plotting.py b/examples/hpinns/plotting.py index 9cd0c8a096..c7f958b0ae 100644 --- a/examples/hpinns/plotting.py +++ b/examples/hpinns/plotting.py @@ -20,6 +20,7 @@ from typing import Callable from typing import Dict from typing import List +from typing import Optional import functions as func_module import matplotlib.pyplot as plt @@ -114,16 +115,16 @@ def prepare_data(solver: ppsci.solver.Solver, expr_dict: Dict[str, Callable]): def plot_field_holo( coord_visual: np.ndarray, field_visual: np.ndarray, - coord_lambda: np.ndarray, - field_lambda: np.ndarray, + coord_lambda: Optional[np.ndarray] = None, + field_lambda: Optional[np.ndarray] = None, ): """Plot fields of of holography example. Args: coord_visual (np.ndarray): The coord of epsilon and |E|**2. field_visual (np.ndarray): The filed of epsilon and |E|**2. - coord_lambda (np.ndarray): The coord of lambda. - field_lambda (np.ndarray): The filed of lambda. + coord_lambda (Optional[np.ndarray], optional): The coord of lambda. Defaults to None. + field_lambda (Optional[np.ndarray], optional): The filed of lambda. Defaults to None. """ fmin, fmax = np.array([0, 1.0]), np.array([0.6, 12]) cmin, cmax = coord_visual.min(axis=(0, 1)), coord_visual.max(axis=(0, 1)) @@ -168,7 +169,7 @@ def plot_field_holo( cb = plt.colorbar() plt.axis((emin[0], emax[0], emin[1], emax[1])) plt.clim(vmin=fmin[fi], vmax=fmax[fi]) - else: + elif coord_lambda is not None and field_lambda is not None: # Fig_6C_lambda_ plt.figure(fi * 100 + 101, figsize=(8, 6)) plt.clf()