Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/darcy/darcy2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def poisson_ref_compute_func(_in):
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
)
visualizer = {
"visualize_p": ppsci.visualize.VisualizerVtu(
"visualize_p_ux_uy": ppsci.visualize.VisualizerVtu(
vis_points,
{
"p": lambda d: d["p"],
Expand Down Expand Up @@ -246,7 +246,7 @@ def poisson_ref_compute_func(_in):
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
)
visualizer = {
"visualize_p": ppsci.visualize.VisualizerVtu(
"visualize_p_ux_uy": ppsci.visualize.VisualizerVtu(
vis_points,
{
"p": lambda d: d["p"],
Expand Down
13 changes: 7 additions & 6 deletions ppsci/solver/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
from typing import TYPE_CHECKING
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -59,7 +60,7 @@ def _get_dataset_length(


def _eval_by_dataset(
solver: "solver.Solver", epoch_id: int, log_freq: int
solver: "solver.Solver", epoch_id: Optional[int], log_freq: int
) -> Tuple[float, Dict[str, Dict[str, float]]]:
"""Evaluate with computing metric on total samples(default process).

Expand All @@ -68,7 +69,7 @@ def _eval_by_dataset(

Args:
solver (solver.Solver): Main Solver.
epoch_id (int): Epoch id.
epoch_id (Optional[int]): Epoch id.
log_freq (int): Log evaluation information every `log_freq` steps.

Returns:
Expand Down Expand Up @@ -189,7 +190,7 @@ def _eval_by_dataset(


def _eval_by_batch(
solver: "solver.Solver", epoch_id: int, log_freq: int
solver: "solver.Solver", epoch_id: Optional[int], log_freq: int
) -> Tuple[float, Dict[str, Dict[str, float]]]:
"""Evaluate with computing metric by batch, which is memory-efficient.

Expand All @@ -199,7 +200,7 @@ def _eval_by_batch(

Args:
solver (solver.Solver): Main Solver.
epoch_id (int): Epoch id.
epoch_id (Optional[int]): Epoch id.
log_freq (int): Log evaluation information every `log_freq` steps.

Returns:
Expand Down Expand Up @@ -303,13 +304,13 @@ def _eval_by_batch(


def eval_func(
solver: "solver.Solver", epoch_id: int, log_freq: int
solver: "solver.Solver", epoch_id: Optional[int], log_freq: int
) -> Tuple[float, Dict[str, Dict[str, float]]]:
"""Evaluation function.

Args:
solver (solver.Solver): Main Solver.
epoch_id (int): Epoch id.
epoch_id (Optional[int]): Epoch id.
log_freq (int): Log evaluation information every `log_freq` steps.

Returns:
Expand Down
16 changes: 11 additions & 5 deletions ppsci/solver/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,17 @@ def log_eval_info(

epoch_width = len(str(solver.epochs))
iters_width = len(str(iters_per_epoch))
logger.info(
f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
)
if isinstance(epoch_id, int):
logger.info(
f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
)
else:
logger.info(
f"[Eval][Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
)

# logger.scalar(
# {
Expand Down
21 changes: 15 additions & 6 deletions ppsci/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,13 @@ def finetune(self, pretrained_model_path: str) -> None:
self.train()

@misc.run_on_eval_mode
def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
def eval(
self, epoch_id: Optional[int] = None
) -> Tuple[float, Dict[str, Dict[str, float]]]:
"""Evaluation.

Args:
epoch_id (int, optional): Epoch id. Defaults to 0.
epoch_id (Optional[int]): Epoch id. Defaults to None.

Returns:
Tuple[float, Dict[str, Dict[str, float]]]: A targe metric value(float) and
Expand All @@ -636,23 +638,30 @@ def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
metric_msg = ", ".join(
[self.eval_output_info[key].avg_info for key in self.eval_output_info]
)
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")

if isinstance(epoch_id, int):
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
else:
logger.info(f"[Eval][Avg] {metric_msg}")
self.eval_output_info.clear()

return result

@misc.run_on_eval_mode
def visualize(self, epoch_id: int = 0):
def visualize(self, epoch_id: Optional[int] = None):
"""Visualization.

Args:
epoch_id (int, optional): Epoch id. Defaults to 0.
epoch_id (Optional[int]): Epoch id. Defaults to None.
"""
# set visualize func
self.visu_func = ppsci.solver.visu.visualize_func

self.visu_func(self, epoch_id)
logger.info(f"[Visualize][Epoch {epoch_id}] Finish visualization")
if isinstance(epoch_id, int):
logger.info(f"[Visualize][Epoch {epoch_id}] Finish visualization")
else:
logger.info("[Visualize] Finish visualization")

@misc.run_on_eval_mode
def predict(
Expand Down
9 changes: 6 additions & 3 deletions ppsci/solver/visu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import os.path as osp
from typing import TYPE_CHECKING
from typing import Optional

import paddle

Expand All @@ -26,12 +27,12 @@
from ppsci import solver


def visualize_func(solver: "solver.Solver", epoch_id: int):
def visualize_func(solver: "solver.Solver", epoch_id: Optional[int]):
"""Visualization program.

Args:
solver (solver.Solver): Main Solver.
epoch_id (int): Epoch id.
epoch_id (Optional[int]): Epoch id.
"""
for _, _visualizer in solver.visualizer.items():
all_input = misc.Prettydefaultdict(list)
Expand Down Expand Up @@ -87,7 +88,9 @@ def visualize_func(solver: "solver.Solver", epoch_id: int):
# save visualization
with misc.RankZeroOnly(solver.rank) as is_master:
if is_master:
visual_dir = osp.join(solver.output_dir, "visual", f"epoch_{epoch_id}")
visual_dir = osp.join(solver.output_dir, "visual")
if epoch_id:
visual_dir = osp.join(visual_dir, f"epoch_{epoch_id}")
os.makedirs(visual_dir, exist_ok=True)
_visualizer.save(
osp.join(visual_dir, _visualizer.prefix),
Expand Down
1 change: 1 addition & 0 deletions ppsci/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def _download(url, path, md5sum=None):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
logger.message(f"Finished downloading pretrained model and saved to {fullname}")

return fullname

Expand Down
2 changes: 1 addition & 1 deletion ppsci/utils/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, model: nn.Layer, decay: Optional[float] = None):
self.model = model # As a quick reference to online model
self.decay = decay

self.params_shadow: Dict[str, paddle.Tensor] = {} # ema param or bufer
self.params_shadow: Dict[str, paddle.Tensor] = {} # ema param or buffer
self.params_backup: Dict[str, paddle.Tensor] = {} # used for apply and restore
for name, param_or_buffer in itertools.chain(
self.model.named_parameters(), self.model.named_buffers()
Expand Down