Skip to content

Commit 7930574

Browse files
pretty evaluation output with prettytable
1 parent 01bf607 commit 7930574

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

examples/darcy/darcy2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def poisson_ref_compute_func(_in):
246246
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
247247
)
248248
visualizer = {
249-
"visualize_p": ppsci.visualize.VisualizerVtu(
249+
"visualize_p_ux_uy": ppsci.visualize.VisualizerVtu(
250250
vis_points,
251251
{
252252
"p": lambda d: d["p"],

ppsci/solver/solver.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import numpy as np
3030
import paddle
3131
import paddle.distributed as dist
32+
import prettytable
3233
import sympy as sp
3334
import visualdl as vdl
3435
from packaging import version
@@ -461,10 +462,21 @@ def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
461462
self.eval_func = ppsci.solver.eval.eval_func
462463

463464
result = self.eval_func(self, epoch_id, self.log_freq)
464-
metric_msg = ", ".join(
465-
[self.eval_output_info[key].avg_info for key in self.eval_output_info]
465+
metric_table = prettytable.PrettyTable(
466+
["Name", "Value"],
467+
title=f"Evaluation Metric(s){'' if epoch_id == 0 else f' at epoch {epoch_id}'}",
468+
align="l",
466469
)
467-
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
470+
471+
loss_msg = []
472+
for name, value in self.eval_output_info.items():
473+
if name.startswith("loss"):
474+
loss_msg.append(f"{name}: {value.avg_fmt}")
475+
else:
476+
metric_table.add_row([name, value.avg_fmt])
477+
loss_msg = ", ".join(loss_msg)
478+
479+
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {loss_msg}\n{metric_table}")
468480
self.eval_output_info.clear()
469481

470482
return result

ppsci/utils/misc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ def avg_info(self):
9090
self.avg = float(self.avg)
9191
return f"{self.name}: {self.avg:.5f}"
9292

93+
@property
94+
def avg_fmt(self):
95+
if isinstance(self.avg, paddle.Tensor):
96+
self.avg = float(self.avg)
97+
return f"{self.avg:.5e}"
98+
9399
@property
94100
def total(self):
95101
return f"{self.name}_sum: {self.sum:{self.fmt}}{self.postfix}"

0 commit comments

Comments
 (0)