@@ -46,7 +46,6 @@ def _load_pretrain_from_path(
4646 path : str ,
4747 model : nn .Layer ,
4848 equation : Optional [Dict [str , equation .PDE ]] = None ,
49- loss_aggregator : Optional [mtl .LossAggregator ] = None ,
5049):
5150 """Load pretrained model from given path.
5251
@@ -81,26 +80,11 @@ def _load_pretrain_from_path(
8180 f"Finish loading pretrained equation parameters from: { path } .pdeqn"
8281 )
8382
84- if loss_aggregator is not None :
85- if not os .path .exists (f"{ path } .pdagg" ):
86- if loss_aggregator .should_persist :
87- logger .warning (
88- f"Given loss_aggregator({ type (loss_aggregator )} ) has persistable"
89- f"parameters or buffers, but { path } .pdagg not found."
90- )
91- else :
92- aggregator_dict = paddle .load (f"{ path } .pdagg" )
93- loss_aggregator .set_state_dict (aggregator_dict )
94- logger .message (
95- f"Finish loading pretrained equation parameters from: { path } .pdagg"
96- )
97-
9883
9984def load_pretrain (
10085 model : nn .Layer ,
10186 path : str ,
10287 equation : Optional [Dict [str , equation .PDE ]] = None ,
103- loss_aggregator : Optional [mtl .LossAggregator ] = None ,
10488):
10589 """
10690 Load pretrained model from given path or url.
@@ -142,7 +126,7 @@ def is_url_accessible(url: str):
142126 # remove ".pdparams" in suffix of path for convenient
143127 if path .endswith (".pdparams" ):
144128 path = path [:- 9 ]
145- _load_pretrain_from_path (path , model , equation , loss_aggregator )
129+ _load_pretrain_from_path (path , model , equation )
146130
147131
148132def load_checkpoint (
0 commit comments