diff --git a/docs/zh/examples/hpinns.md b/docs/zh/examples/hpinns.md
index dc98d6da10..4ecc23a46d 100644
--- a/docs/zh/examples/hpinns.md
+++ b/docs/zh/examples/hpinns.md
@@ -45,7 +45,6 @@
```
-
| 预训练模型 | 指标 |
|:--| :--|
| [hpinns_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/hPINNs/hpinns_pretrained.pdparams) | loss(opt_sup): 0.05352
MSE.eval_metric(opt_sup): 0.00002
loss(val_sup): 0.02205
MSE.eval_metric(val_sup): 0.00001 |
diff --git a/examples/hpinns/holography.py b/examples/hpinns/holography.py
index fd7ea5b784..290f2b3c2a 100644
--- a/examples/hpinns/holography.py
+++ b/examples/hpinns/holography.py
@@ -460,12 +460,19 @@ def inference(cfg: DictConfig):
for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys())
}
- ppsci.visualize.save_vtu_from_dict(
- "./hpinns_pred.vtu",
- {**input_dict, **output_dict},
- input_dict.keys(),
- cfg.INFER.output_keys,
+ # plotting E and eps
+ N = ((func_module.l_BOX[1] - func_module.l_BOX[0]) / 0.05).astype(int)
+ input_eval = np.stack((input_dict["x"], input_dict["y"]), axis=-1).reshape(
+ N[0], N[1], 2
)
+ e_re = output_dict["e_re"].reshape(N[0], N[1])
+ e_im = output_dict["e_im"].reshape(N[0], N[1])
+ eps = output_dict["eps"].reshape(N[0], N[1])
+ v_visual = e_re**2 + e_im**2
+ field_visual = np.stack((v_visual, eps), axis=-1)
+ plot_module.field_name = ["Fig7_E", "Fig7_eps"]
+ plot_module.FIGNAME = "hpinns_pred"
+ plot_module.plot_field_holo(input_eval, field_visual)
@hydra.main(version_base=None, config_path="./conf", config_name="hpinns.yaml")
diff --git a/examples/hpinns/plotting.py b/examples/hpinns/plotting.py
index 9cd0c8a096..c7f958b0ae 100644
--- a/examples/hpinns/plotting.py
+++ b/examples/hpinns/plotting.py
@@ -20,6 +20,7 @@
from typing import Callable
from typing import Dict
from typing import List
+from typing import Optional
import functions as func_module
import matplotlib.pyplot as plt
@@ -114,16 +115,16 @@ def prepare_data(solver: ppsci.solver.Solver, expr_dict: Dict[str, Callable]):
def plot_field_holo(
coord_visual: np.ndarray,
field_visual: np.ndarray,
- coord_lambda: np.ndarray,
- field_lambda: np.ndarray,
+ coord_lambda: Optional[np.ndarray] = None,
+ field_lambda: Optional[np.ndarray] = None,
):
"""Plot fields of of holography example.
Args:
coord_visual (np.ndarray): The coord of epsilon and |E|**2.
field_visual (np.ndarray): The filed of epsilon and |E|**2.
- coord_lambda (np.ndarray): The coord of lambda.
- field_lambda (np.ndarray): The filed of lambda.
+ coord_lambda (Optional[np.ndarray], optional): The coord of lambda. Defaults to None.
+ field_lambda (Optional[np.ndarray], optional): The filed of lambda. Defaults to None.
"""
fmin, fmax = np.array([0, 1.0]), np.array([0.6, 12])
cmin, cmax = coord_visual.min(axis=(0, 1)), coord_visual.max(axis=(0, 1))
@@ -168,7 +169,7 @@ def plot_field_holo(
cb = plt.colorbar()
plt.axis((emin[0], emax[0], emin[1], emax[1]))
plt.clim(vmin=fmin[fi], vmax=fmax[fi])
- else:
+ elif coord_lambda is not None and field_lambda is not None:
# Fig_6C_lambda_
plt.figure(fi * 100 + 101, figsize=(8, 6))
plt.clf()