Skip to content

Commit

Permalink
add example for AMGNet
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Oct 18, 2023
1 parent 7eb4aab commit 876ef4d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions ppsci/arch/amgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,8 @@ class AMGNet(nn.Layer):
Code reference: https://github.com/baoshiaijhin/amgnet
Args:
input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
input_keys (Tuple[str, ...]): Name of input keys, such as ("input", ).
output_keys (Tuple[str, ...]): Name of output keys, such as ("pred", ).
input_dim (int): Number of input dimension.
output_dim (int): Number of output dimension.
latent_dim (int): Number of hidden(feature) dimension.
Expand All @@ -576,6 +576,10 @@ class AMGNet(nn.Layer):
message_passing_steps (int): Message passing steps in graph.
speed (str): Whether use vanilla method or fast method for graph_connectivity
computation.
Examples:
>>> import ppsci
>>> model = ppsci.arch.AMGNet(("input", ), ("pred", ), 5, 3, 64, 2)
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion ppsci/data/process/batch_transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def default_collate_fn(batch: List[Any]) -> Any:
return graph

raise TypeError(
"batch data can only contains: Tensor, numpy.ndarray, "
"batch data can only contains: paddle.Tensor, numpy.ndarray, "
f"dict, list, number, None, pgl.Graph, but got {type(sample)}"
)

Expand Down
2 changes: 1 addition & 1 deletion ppsci/solver/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _get_datset_length(
def _eval_by_dataset(
solver: "solver.Solver", epoch_id: int, log_freq: int
) -> Tuple[float, Dict[str, Dict[str, float]]]:
"""Evaluate with computing metric on total samples(default progress).
"""Evaluate with computing metric on total samples(default process).
Args:
solver (solver.Solver): Main Solver.
Expand Down

0 comments on commit 876ef4d

Please sign in to comment.