Skip to content

Commit 66cded1

Browse files
Merge branch 'pretty_eval_output' into develop
2 parents da865dd + 7930574 commit 66cded1

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

examples/darcy/darcy2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def poisson_ref_compute_func(_in):
246246
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
247247
)
248248
visualizer = {
249-
"visualize_p": ppsci.visualize.VisualizerVtu(
249+
"visualize_p_ux_uy": ppsci.visualize.VisualizerVtu(
250250
vis_points,
251251
{
252252
"p": lambda d: d["p"],

ppsci/solver/solver.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import numpy as np
3232
import paddle
3333
import paddle.distributed as dist
34+
import prettytable
3435
import sympy as sp
3536
import visualdl as vdl
3637
from packaging import version
@@ -545,10 +546,21 @@ def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
545546
self.eval_func = ppsci.solver.eval.eval_func
546547

547548
result = self.eval_func(self, epoch_id, self.log_freq)
548-
metric_msg = ", ".join(
549-
[self.eval_output_info[key].avg_info for key in self.eval_output_info]
549+
metric_table = prettytable.PrettyTable(
550+
["Name", "Value"],
551+
title=f"Evaluation Metric(s){'' if epoch_id == 0 else f' at epoch {epoch_id}'}",
552+
align="l",
550553
)
551-
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
554+
555+
loss_msg = []
556+
for name, value in self.eval_output_info.items():
557+
if name.startswith("loss"):
558+
loss_msg.append(f"{name}: {value.avg_fmt}")
559+
else:
560+
metric_table.add_row([name, value.avg_fmt])
561+
loss_msg = ", ".join(loss_msg)
562+
563+
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {loss_msg}\n{metric_table}")
552564
self.eval_output_info.clear()
553565

554566
return result

ppsci/utils/misc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ def avg_info(self):
9090
self.avg = float(self.avg)
9191
return f"{self.name}: {self.avg:.5f}"
9292

93+
@property
94+
def avg_fmt(self):
95+
if isinstance(self.avg, paddle.Tensor):
96+
self.avg = float(self.avg)
97+
return f"{self.avg:.5e}"
98+
9399
@property
94100
def total(self):
95101
return f"{self.name}_sum: {self.sum:{self.fmt}}{self.postfix}"

0 commit comments

Comments
 (0)