docs/zh/user_guide.md (+11 -6)
@@ -255,11 +255,11 @@ solver = ppsci.solver.Solver(

### 1.4 Transfer Learning

- Transfer learning is a widely used, low-cost training approach for improving model accuracy. In PaddleScience, you can manually load pretrained model weights into the `model` after it has been instantiated; alternatively, specify `pretrained_model_path` when instantiating the `Solver` to load the pretrained weights automatically. Either way enables transfer learning.
+ Transfer learning is a widely used, low-cost training approach for improving model accuracy. In PaddleScience, you can manually load pretrained model weights into the `model` after it has been instantiated and then start fine-tuning; alternatively, call the `Solver.finetune` interface with the `pretrained_model_path` argument to load the pretrained weights and start fine-tuning automatically.

=== "手动载入预训练模型"

``` py hl_lines="8"
``` py hl_lines="8 12"
    import ppsci
    from ppsci.utils import save_load

@@ -268,24 +268,29 @@ solver = ppsci.solver.Solver(

    model = ...
    save_load.load_pretrain(model, "/path/to/pretrain")
    solver = ppsci.solver.Solver(
        ...,
    )
    solver.train()
    ```

=== "指定 `pretrained_model_path` 自动载入预训练模型"
=== "调用 `Solver.finetune` 接口"

``` py hl_lines="9"
``` py hl_lines="11"
    import ppsci


    ...
    ...

    model = ...
    solver = ppsci.solver.Solver(
        ...,
-         pretrained_model_path="/path/to/pretrain",
    )
+     solver.finetune(pretrained_model_path="/path/to/pretrain")
    ```

- !!! info "Transfer learning tips"
+ !!! tip "Transfer learning tips"

    During transfer learning, the loaded pretrained weights provide a much better initialization than fully random parameters, so there is no need for a large learning rate; reducing it by a factor of 2~10 typically yields a more stable training process and better accuracy.

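To make the learning-rate tip concrete, here is a minimal fine-tuning setup in plain Paddle. The base rate `1e-3`, the 5x reduction factor, and the `paddle.nn.Linear` stand-in model are illustrative assumptions, not values from the PR:

``` py
import paddle

# Stand-in model; in practice this is the model whose pretrained weights
# were just loaded (e.g. via save_load.load_pretrain).
model = paddle.nn.Linear(16, 1)

BASE_LR = 1e-3             # assumed learning rate for training from scratch
finetune_lr = BASE_LR / 5  # reduce by 2~10x for fine-tuning, per the tip

optimizer = paddle.optimizer.Adam(
    learning_rate=finetune_lr,
    parameters=model.parameters(),
)
```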
ppsci/solver/solver.py (+13 -1)
@@ -398,7 +398,7 @@ def convert_expr(
        # set up benchmark flag, will print memory stat if enabled
        self.benchmark_flag: bool = os.getenv("BENCHMARK_ROOT", None) is not None

-     def train(self):
+     def train(self) -> None:
"""Training."""
self.global_step = self.best_metric["epoch"] * self.iters_per_epoch
start_epoch = self.best_metric["epoch"] + 1
@@ -471,6 +471,18 @@ def train(self):
                print_log=(epoch_id == start_epoch),
            )

+     def finetune(self, pretrained_model_path: str) -> None:
+         """Finetune model based on given pretrained model.
+ 
+         Args:
+             pretrained_model_path (str): Pretrained model path.
+         """
+         # load pretrained model
+         save_load.load_pretrain(self.model, pretrained_model_path, self.equation)
+ 
+         # call train program
+         self.train()

    @misc.run_on_eval_mode
    def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
        """Evaluation.
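Taken together, the new interface would be exercised roughly as follows. This is a sketch mirroring the documentation above; the model definition and the remaining `Solver` arguments are elided exactly as in the docs, and `/path/to/pretrain` is a placeholder path:

``` py
import ppsci

model = ...  # build the model as usual
solver = ppsci.solver.Solver(
    model,
    ...,
)

# Loads the pretrained weights into solver.model, then runs training.
solver.finetune(pretrained_model_path="/path/to/pretrain")
```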