Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PPSCI Doc No.104-105】 #759

Merged
merged 5 commits into from
Jan 19, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions ppsci/utils/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,24 @@ def _load_pretrain_from_path(
def load_pretrain(
model: nn.Layer, path: str, equation: Optional[Dict[str, equation.PDE]] = None
):
"""Load pretrained model from given path or url.
"""
Load pretrained model from given path or url.

Args:
model (nn.Layer): Model with parameters.
path (str): File path or url of pretrained model, i.e. `/path/to/model.pdparams`
or `http://xxx.com/model.pdparams`.
equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None.

Examples:
>>> import ppsci
>>> from ppsci.utils import save_load
>>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh")
>>> save_load.load_pretrain(
... model = model,
... path = "path/to/pretrain_model"
ooooo-create marked this conversation as resolved.
Show resolved Hide resolved
... ) # doctest: +SKIP

ooooo-create marked this conversation as resolved.
Show resolved Hide resolved
"""
if path.startswith("http"):
path = download.get_weights_path_from_url(path)
Expand Down Expand Up @@ -159,7 +170,8 @@ def save_checkpoint(
equation: Optional[Dict[str, equation.PDE]] = None,
print_log: bool = True,
):
"""Save checkpoint, including model params, optimizer params, metric information.
"""
Save checkpoint, including model params, optimizer params, metric information.

Args:
model (nn.Layer): Model with parameters.
Expand All @@ -172,6 +184,14 @@ def save_checkpoint(
print_log (bool, optional): Whether print saving log information, mainly for
keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings.
Defaults to True.

Examples:
>>> import ppsci
>>> import paddle
>>> from ppsci.utils import save_load
>>> model = ppsci.arch.MLP(("x", "y", "z"), ("u", "v", "w"), 5, 64, "tanh")
>>> optimizer = ppsci.optimizer.Adam(0.001)(model)
>>> save_load.save_checkpoint(model, optimizer, {"RMSE": 0.1}, output_dir="path/to/output/dir") # doctest: +SKIP
"""
if paddle.distributed.get_rank() != 0:
return
Expand Down