
Commit 1e1958c

add Solver.finetune() (#755)
1 parent 3a5135a commit 1e1958c

2 files changed: +24 −7 lines

docs/zh/user_guide.md

Lines changed: 11 additions & 6 deletions
@@ -255,11 +255,11 @@ solver = ppsci.solver.Solver(
 
 ### 1.4 Transfer learning
 
-Transfer learning is a widely used, low-cost way to improve model accuracy. In PaddleScience, you can load pretrained weights into the `model` manually after it has been instantiated, or specify `pretrained_model_path` when instantiating the `Solver` so that the pretrained weights are loaded automatically; both approaches enable transfer learning.
+Transfer learning is a widely used, low-cost way to improve model accuracy. In PaddleScience, you can load pretrained weights into the `model` manually after it has been instantiated and then start fine-tuning, or call the `Solver.finetune` interface with the `pretrained_model_path` argument to load the pretrained weights automatically and start fine-tuning.
 
 === "Load the pretrained model manually"
 
-    ``` py hl_lines="8"
+    ``` py hl_lines="8 12"
     import ppsci
     from ppsci.utils import save_load
 
@@ -268,24 +268,29 @@ solver = ppsci.solver.Solver(
 
     model = ...
     save_load.load_pretrain(model, "/path/to/pretrain")
+    solver = ppsci.solver.Solver(
+        ...,
+    )
+    solver.train()
     ```
 
-=== "Specify `pretrained_model_path` to load the pretrained model automatically"
+=== "Call the `Solver.finetune` interface"
 
-    ``` py hl_lines="9"
+    ``` py hl_lines="11"
     import ppsci
 
+
     ...
     ...
 
     model = ...
     solver = ppsci.solver.Solver(
         ...,
-        pretrained_model_path="/path/to/pretrain",
     )
+    solver.finetune(pretrained_model_path="/path/to/pretrain")
     ```
 
-!!! info "Transfer learning tips"
+!!! tip "Transfer learning tips"
 
     When fine-tuning, the loaded pretrained weights are a much better initialization than fully randomly initialized parameters, so there is no need for a large learning rate; reducing it by roughly 2~10x usually gives a more stable training process and better accuracy.
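
Putting the new tab and the tip together, the fine-tuning workflow reads as below. This is a minimal sketch rather than code from this commit: the constraint setup is elided with `...` in the guide's own style, and the `ppsci.arch.MLP` model, the `ppsci.optimizer.Adam` factory call, the `"./output_finetune"` directory, and the `1e-4` learning rate (simply applying the "2~10x smaller" suggestion) are all illustrative assumptions.

``` py
import ppsci

# hypothetical model; any ppsci.arch network would be used the same way
model = ppsci.arch.MLP(("x", "y"), ("u",), 5, 20)

# problem-specific constraints, elided here as in the guide's own examples
constraint = ...

# per the tip above: fine-tune with a learning rate 2~10x smaller than the
# one used for pretraining (the 1e-4 value is purely illustrative)
optimizer = ppsci.optimizer.Adam(1e-4)(model)

solver = ppsci.solver.Solver(
    model,
    constraint,
    "./output_finetune",
    optimizer,
    epochs=100,
    iters_per_epoch=100,
)

# load the pretrained weights into `model`, then run the normal training loop
solver.finetune(pretrained_model_path="/path/to/pretrain")
```
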

ppsci/solver/solver.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def convert_expr(
         # set up benchmark flag, will print memory stat if enabled
         self.benchmark_flag: bool = os.getenv("BENCHMARK_ROOT", None) is not None
 
-    def train(self):
+    def train(self) -> None:
         """Training."""
         self.global_step = self.best_metric["epoch"] * self.iters_per_epoch
         start_epoch = self.best_metric["epoch"] + 1
@@ -471,6 +471,18 @@ def train(self):
                 print_log=(epoch_id == start_epoch),
             )
 
+    def finetune(self, pretrained_model_path: str) -> None:
+        """Finetune model based on given pretrained model.
+
+        Args:
+            pretrained_model_path (str): Pretrained model path.
+        """
+        # load pretrained model
+        save_load.load_pretrain(self.model, pretrained_model_path, self.equation)
+
+        # call train program
+        self.train()
+
     @misc.run_on_eval_mode
     def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
         """Evaluation.
