@@ -31,6 +31,7 @@
 from paddle import optimizer

 from ppsci import equation
+from ppsci.loss import mtl
 from ppsci.utils import ema


@@ -42,7 +43,10 @@


 def _load_pretrain_from_path(
-    path: str, model: nn.Layer, equation: Optional[Dict[str, equation.PDE]] = None
+    path: str,
+    model: nn.Layer,
+    equation: Optional[Dict[str, equation.PDE]] = None,
+    loss_aggregator: Optional[mtl.LossAggregator] = None,
 ):
     """Load pretrained model from given path.

@@ -77,9 +81,26 @@ def _load_pretrain_from_path(
7781 f"Finish loading pretrained equation parameters from: { path } .pdeqn"
7882 )
7983
84+ if loss_aggregator is not None :
85+ if not os .path .exists (f"{ path } .pdagg" ):
86+ if loss_aggregator .should_persist :
87+ logger .warning (
88+ f"Given loss_aggregator({ type (loss_aggregator )} ) has persistable"
89+ f"parameters or buffers, but { path } .pdagg not found."
90+ )
91+ else :
92+ aggregator_dict = paddle .load (f"{ path } .pdagg" )
93+ loss_aggregator .set_state_dict (aggregator_dict )
94+ logger .message (
95+ f"Finish loading pretrained equation parameters from: { path } .pdagg"
96+ )
97+
8098
8199def load_pretrain (
82- model : nn .Layer , path : str , equation : Optional [Dict [str , equation .PDE ]] = None
100+ model : nn .Layer ,
101+ path : str ,
102+ equation : Optional [Dict [str , equation .PDE ]] = None ,
103+ loss_aggregator : Optional [mtl .LossAggregator ] = None ,
83104):
84105 """
85106 Load pretrained model from given path or url.
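Both entry points now thread the aggregator through to the loader, and the `should_persist` check lets stateless aggregators (e.g., a fixed weighted sum) skip serialization entirely. A minimal usage sketch, assuming the signatures shown in this diff; the MLP shape, the `GradNorm` constructor arguments, and the paths are illustrative, not part of this change:

import ppsci
from ppsci.loss import mtl
from ppsci.utils import save_load

model = ppsci.arch.MLP(("x", "y"), ("u",), 5, 20)
# Hypothetical aggregator: any mtl.LossAggregator subclass whose
# `should_persist` is True round-trips through the .pdagg file.
aggregator = mtl.GradNorm(model, num_losses=2)
save_load.load_pretrain(
    model,
    "./pretrained/model",  # a trailing ".pdparams" is stripped automatically
    loss_aggregator=aggregator,
)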
@@ -121,7 +142,7 @@ def is_url_accessible(url: str):
     # remove ".pdparams" in suffix of path for convenience
     if path.endswith(".pdparams"):
         path = path[:-9]
-    _load_pretrain_from_path(path, model, equation)
+    _load_pretrain_from_path(path, model, equation, loss_aggregator)


 def load_checkpoint(
@@ -131,6 +152,7 @@ def load_checkpoint(
     grad_scaler: Optional[amp.GradScaler] = None,
     equation: Optional[Dict[str, equation.PDE]] = None,
     ema_model: Optional[ema.AveragedModel] = None,
+    aggregator: Optional[mtl.LossAggregator] = None,
 ) -> Dict[str, Any]:
     """Load from checkpoint.

@@ -141,6 +163,7 @@ def load_checkpoint(
         grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None.
         equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None.
         ema_model (Optional[ema.AveragedModel]): Average model. Defaults to None.
+        aggregator (Optional[mtl.LossAggregator]): Loss aggregator. Defaults to None.

     Returns:
         Dict[str, Any]: Loaded metric information.
@@ -189,6 +212,10 @@ def load_checkpoint(
         avg_param_dict = paddle.load(f"{path}_ema.pdparams")
         ema_model.set_state_dict(avg_param_dict)

+    if aggregator is not None:
+        aggregator_dict = paddle.load(f"{path}.pdagg")
+        aggregator.set_state_dict(aggregator_dict)
+
     logger.message(f"Finish loading checkpoint from {path}")
     return metric_dict

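Note the asymmetry with the pretrain path above: this branch loads `{path}.pdagg` unconditionally whenever an aggregator is passed, so resuming expects the file to exist. Continuing the sketch (the optimizer, checkpoint path, and returned names are illustrative):

from paddle import optimizer as optim

opt = optim.Adam(learning_rate=1e-3, parameters=model.parameters())
# Restores model/optimizer state plus the aggregator's .pdagg file and
# returns the metric dict stored with the checkpoint.
metric_dict = save_load.load_checkpoint(
    "./output/checkpoints/latest",
    model,
    opt,
    aggregator=aggregator,
)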
@@ -203,6 +230,7 @@ def save_checkpoint(
     equation: Optional[Dict[str, equation.PDE]] = None,
     print_log: bool = True,
     ema_model: Optional[ema.AveragedModel] = None,
+    aggregator: Optional[mtl.LossAggregator] = None,
 ):
     """
     Save checkpoint, including model params, optimizer params, metric information.
@@ -219,6 +247,7 @@ def save_checkpoint(
             keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings.
             Defaults to True.
         ema_model (Optional[ema.AveragedModel]): Average model. Defaults to None.
+        aggregator (Optional[mtl.LossAggregator]): Loss aggregator. Defaults to None.

     Examples:
         >>> import ppsci
@@ -258,6 +287,9 @@ def save_checkpoint(
     if ema_model:
         paddle.save(ema_model.state_dict(), f"{ckpt_path}_ema.pdparams")

+    if aggregator and aggregator.should_persist:
+        paddle.save(aggregator.state_dict(), f"{ckpt_path}.pdagg")
+
     if print_log:
         log_str = f"Finish saving checkpoint to: {ckpt_path}"
         if prefix == "latest":
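On the save side, the `.pdagg` file is written only when the aggregator actually carries persistable parameters or buffers, so checkpoints of stateless aggregators are unchanged. A closing sketch of the round-trip, assuming `output_dir`/`prefix` keywords as used elsewhere in this file; the metric dict and directory are illustrative:

# Writes the usual model/optimizer/metric files, plus "<prefix>.pdagg"
# only when aggregator.should_persist is True.
save_load.save_checkpoint(
    model,
    opt,
    {"L2": 0.01},  # illustrative metric dict
    output_dir="./output",
    prefix="latest",
    aggregator=aggregator,
)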