Skip to content

Commit 932fe0d

Browse files
[Fix] Fix aggregator error when load ckpt (#1038)
* fix aggregator error when load ckpt * add unit test
1 parent 7371f85 commit 932fe0d

File tree

3 files changed

+108
-9
lines changed

3 files changed

+108
-9
lines changed

ppsci/solver/solver.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,15 @@ def __init__(
332332
"metric": float("inf"),
333333
"epoch": 0,
334334
}
335+
336+
# use loss aggregator, use Sum if None
337+
if isinstance(loss_aggregator, (mtl.AGDA, mtl.PCGrad)) and self.use_amp:
338+
raise ValueError(
339+
"Auto Mix Precision do not support AGDA, PCGrad loss aggregator yet, "
340+
"please set use_amp=False."
341+
)
342+
self.loss_aggregator = loss_aggregator or mtl.Sum()
343+
335344
# load model checkpoint, usually used for resume training
336345
if not cfg:
337346
self.checkpoint_path = checkpoint_path
@@ -478,14 +487,6 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
478487
jit.enable_to_static(to_static)
479488
logger.message(f"Set to_static={to_static} for computational optimization.")
480489

481-
# use loss aggregator, use Sum if None
482-
if isinstance(loss_aggregator, (mtl.AGDA, mtl.PCGrad)) and self.use_amp:
483-
raise ValueError(
484-
"Auto Mix Precision do not support AGDA, PCGrad loss aggregator yet, "
485-
"please set use_amp=False."
486-
)
487-
self.loss_aggregator = loss_aggregator or mtl.Sum()
488-
489490
# convert sympy to callable object if exist
490491
extra_parameters = []
491492
if self.equation:

ppsci/utils/save_load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def load_checkpoint(
196196
avg_param_dict = paddle.load(f"{path}_ema.pdparams")
197197
ema_model.set_state_dict(avg_param_dict)
198198

199-
if aggregator is not None:
199+
if aggregator is not None and aggregator.should_persist:
200200
aggregator_dict = paddle.load(f"{path}.pdagg")
201201
aggregator.set_state_dict(aggregator_dict)
202202

test/loss/aggregator.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import pytest
2+
3+
import ppsci
4+
from ppsci import arch
5+
from ppsci.loss import mtl
6+
7+
__all__ = []
8+
9+
10+
class AggregatorTest:
11+
def __init__(self):
12+
self.model = arch.MLP(
13+
("x", "y"),
14+
("u", "v"),
15+
3,
16+
16,
17+
)
18+
19+
def _check_agg_state_dict(self, agg):
20+
model_state = self.model.state_dict()
21+
agg_state = agg.state_dict()
22+
for k in agg_state:
23+
assert k not in model_state
24+
25+
def test_AGDA(self):
26+
aggregator = mtl.AGDA(self.model)
27+
assert aggregator.should_persist is False
28+
29+
def test_GradNorm(self):
30+
aggregator = mtl.GradNorm(self.model)
31+
assert aggregator.should_persist is True
32+
self._check_agg_state_dict(aggregator)
33+
34+
def test_LossAggregator(self):
35+
aggregator = mtl.AGDA(self.model)
36+
assert aggregator.should_persist is False
37+
38+
def test_PCGrad(self):
39+
aggregator = mtl.PCGrad(self.model)
40+
assert aggregator.should_persist is False
41+
42+
def test_Relobralo(self):
43+
aggregator = mtl.Relobralo(self.model)
44+
assert aggregator.should_persist is True
45+
self._check_agg_state_dict(aggregator)
46+
47+
def test_Sum(self):
48+
aggregator = mtl.Sum(self.model)
49+
assert aggregator.should_persist is False
50+
51+
def test_NTK(self):
52+
aggregator = mtl.NTK(self.model)
53+
assert aggregator.should_persist is True
54+
self._check_agg_state_dict(aggregator)
55+
56+
def test_restore_aggregator(self):
57+
model = ppsci.arch.MLP(
58+
["x", "y"],
59+
["u"],
60+
2,
61+
16,
62+
)
63+
opt = ppsci.optimizer.Adam(1e-3)(model)
64+
equation = ppsci.equation.Laplace(2)
65+
geom = ppsci.geometry.Rectangle([0, 0], [1, 1])
66+
BC = ppsci.constraint.BoundaryConstraint(
67+
equation.equations,
68+
{"laplace": 0.0},
69+
geom,
70+
{
71+
"dataset": "IterableNamedArrayDataset",
72+
"iters_per_epoch": 10,
73+
"batch_size": 16,
74+
},
75+
loss=ppsci.loss.MSELoss(),
76+
)
77+
solver = ppsci.solver.Solver(
78+
model,
79+
{"bound": BC},
80+
optimizer=opt,
81+
output_dir="./tmp",
82+
iters_per_epoch=10,
83+
epochs=2,
84+
)
85+
solver.train()
86+
solver = ppsci.solver.Solver(
87+
model,
88+
{"bound": BC},
89+
optimizer=opt,
90+
output_dir="./tmp",
91+
iters_per_epoch=10,
92+
epochs=2,
93+
checkpoint_path="./tmp/checkpoints/latest",
94+
)
95+
96+
97+
if __name__ == "__main__":
98+
pytest.main()

0 commit comments

Comments (0)