From 876ef4dad58d1ee62768ad510667274c0af07d70 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 18 Oct 2023 05:38:15 +0000 Subject: [PATCH] add example for AMGNet --- ppsci/arch/amgnet.py | 8 ++++++-- ppsci/data/process/batch_transform/__init__.py | 2 +- ppsci/solver/eval.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ppsci/arch/amgnet.py b/ppsci/arch/amgnet.py index 5a3eb45cd..a0bf534db 100644 --- a/ppsci/arch/amgnet.py +++ b/ppsci/arch/amgnet.py @@ -565,8 +565,8 @@ class AMGNet(nn.Layer): Code reference: https://github.com/baoshiaijhin/amgnet Args: - input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z"). - output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w"). + input_keys (Tuple[str, ...]): Name of input keys, such as ("input", ). + output_keys (Tuple[str, ...]): Name of output keys, such as ("pred", ). input_dim (int): Number of input dimension. output_dim (int): Number of output dimension. latent_dim (int): Number of hidden(feature) dimension. @@ -576,6 +576,10 @@ class AMGNet(nn.Layer): message_passing_steps (int): Message passing steps in graph. speed (str): Whether use vanilla method or fast method for graph_connectivity computation. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AMGNet(("input", ), ("pred", ), 5, 3, 64, 2) """ def __init__( diff --git a/ppsci/data/process/batch_transform/__init__.py b/ppsci/data/process/batch_transform/__init__.py index 2f50850d2..5eceb2bb5 100644 --- a/ppsci/data/process/batch_transform/__init__.py +++ b/ppsci/data/process/batch_transform/__init__.py @@ -76,7 +76,7 @@ def default_collate_fn(batch: List[Any]) -> Any: return graph raise TypeError( - "batch data can only contains: Tensor, numpy.ndarray, " + "batch data can only contains: paddle.Tensor, numpy.ndarray, " f"dict, list, number, None, pgl.Graph, but got {type(sample)}" ) diff --git a/ppsci/solver/eval.py b/ppsci/solver/eval.py index fb4a570e6..09a97bba3 100644 --- a/ppsci/solver/eval.py +++ b/ppsci/solver/eval.py @@ -63,7 +63,7 @@ def _get_datset_length( def _eval_by_dataset( solver: "solver.Solver", epoch_id: int, log_freq: int ) -> Tuple[float, Dict[str, Dict[str, float]]]: - """Evaluate with computing metric on total samples(default progress). + """Evaluate with computing metric on total samples(default process). Args: solver (solver.Solver): Main Solver.