From b32f3819aff4f765c5de049f6810c77abf51236e Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Thu, 18 Jan 2024 20:36:28 +0800 Subject: [PATCH 1/4] update docstring --- ppsci/utils/save_load.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/ppsci/utils/save_load.py b/ppsci/utils/save_load.py index d1616725a..888f3314f 100644 --- a/ppsci/utils/save_load.py +++ b/ppsci/utils/save_load.py @@ -80,13 +80,24 @@ def _load_pretrain_from_path( def load_pretrain( model: nn.Layer, path: str, equation: Optional[Dict[str, equation.PDE]] = None ): - """Load pretrained model from given path or url. + """ + Load pretrained model from given path or url. Args: model (nn.Layer): Model with parameters. path (str): File path or url of pretrained model, i.e. `/path/to/model.pdparams` or `http://xxx.com/model.pdparams`. equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None. + + Examples: + >>> import ppsci + >>> from ppsci.utils import save_load + >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh") + >>> save_load.load_pretrain( + ... model = model, + ... path = "path/to/pretrain_model" + ... ) # doctest: +SKIP + """ if path.startswith("http"): path = download.get_weights_path_from_url(path) @@ -159,7 +170,8 @@ def save_checkpoint( equation: Optional[Dict[str, equation.PDE]] = None, print_log: bool = True, ): - """Save checkpoint, including model params, optimizer params, metric information. + """ + Save checkpoint, including model params, optimizer params, metric information. Args: model (nn.Layer): Model with parameters. @@ -172,6 +184,14 @@ def save_checkpoint( print_log (bool, optional): Whether print saving log information, mainly for keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings. Defaults to True. + + Examples: + >>> import ppsci + >>> import paddle + >>> from ppsci.utils import save_load + >>> model = ppsci.arch.MLP(("x", "y", "z"), ("u", "v", "w"), 5, 64, "tanh") + >>> optimizer = ppsci.optimizer.Adam(0.001)(model) + >>> save_load.save_checkpoint(model, optimizer, {"RMSE": 0.1}, output_dir="path/to/output/dir") # doctest: +SKIP """ if paddle.distributed.get_rank() != 0: return From 50c74f921c15c190bd99321743b01534155ae6bb Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Fri, 19 Jan 2024 12:22:29 +0800 Subject: [PATCH 2/4] remove spaces in args --- ppsci/utils/save_load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppsci/utils/save_load.py b/ppsci/utils/save_load.py index 888f3314f..9d454b315 100644 --- a/ppsci/utils/save_load.py +++ b/ppsci/utils/save_load.py @@ -94,8 +94,8 @@ def load_pretrain( >>> from ppsci.utils import save_load >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh") >>> save_load.load_pretrain( - ... model = model, - ... path = "path/to/pretrain_model" + ... model=model, + ... path="path/to/pretrain_model" ... ) # doctest: +SKIP """ From 3df7ea6fa53a5a8fd614d98d3baa0a2d7f8feea8 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Fri, 19 Jan 2024 12:31:25 +0800 Subject: [PATCH 3/4] update --- ppsci/utils/save_load.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ppsci/utils/save_load.py b/ppsci/utils/save_load.py index 9d454b315..b59a087eb 100644 --- a/ppsci/utils/save_load.py +++ b/ppsci/utils/save_load.py @@ -95,8 +95,7 @@ def load_pretrain( >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh") >>> save_load.load_pretrain( ... model=model, - ... path="path/to/pretrain_model" - ... ) # doctest: +SKIP + ... path="path/to/pretrain_model") # doctest: +SKIP """ if path.startswith("http"): From c674103af2552fa39fda670eff894c94398ab956 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Fri, 19 Jan 2024 12:37:11 +0800 Subject: [PATCH 4/4] remove extra blank line --- ppsci/utils/save_load.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppsci/utils/save_load.py b/ppsci/utils/save_load.py index b59a087eb..b882c515d 100644 --- a/ppsci/utils/save_load.py +++ b/ppsci/utils/save_load.py @@ -96,7 +96,6 @@ def load_pretrain( >>> save_load.load_pretrain( ... model=model, ... path="path/to/pretrain_model") # doctest: +SKIP - """ if path.startswith("http"): path = download.get_weights_path_from_url(path)