diff --git a/ppsci/arch/amgnet.py b/ppsci/arch/amgnet.py index 5a3eb45cd5..a0bf534dba 100644 --- a/ppsci/arch/amgnet.py +++ b/ppsci/arch/amgnet.py @@ -565,8 +565,8 @@ class AMGNet(nn.Layer): Code reference: https://github.com/baoshiaijhin/amgnet Args: - input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z"). - output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w"). + input_keys (Tuple[str, ...]): Name of input keys, such as ("input", ). + output_keys (Tuple[str, ...]): Name of output keys, such as ("pred", ). input_dim (int): Number of input dimension. output_dim (int): Number of output dimension. latent_dim (int): Number of hidden(feature) dimension. @@ -576,6 +576,10 @@ class AMGNet(nn.Layer): message_passing_steps (int): Message passing steps in graph. speed (str): Whether use vanilla method or fast method for graph_connectivity computation. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AMGNet(("input", ), ("pred", ), 5, 3, 64, 2) """ def __init__( diff --git a/ppsci/data/process/batch_transform/__init__.py b/ppsci/data/process/batch_transform/__init__.py index 2f50850d21..5eceb2bb54 100644 --- a/ppsci/data/process/batch_transform/__init__.py +++ b/ppsci/data/process/batch_transform/__init__.py @@ -76,7 +76,7 @@ def default_collate_fn(batch: List[Any]) -> Any: return graph raise TypeError( - "batch data can only contains: Tensor, numpy.ndarray, " + "batch data can only contains: paddle.Tensor, numpy.ndarray, " f"dict, list, number, None, pgl.Graph, but got {type(sample)}" )