@@ -254,14 +254,80 @@ def u_solution_func(in_) -> np.ndarray:
254254 plt .savefig (osp .join (cfg .output_dir , "./Volterra_IDE.png" ), dpi = 200 )
255255
256256
257+ def export (cfg : DictConfig ):
258+ # set model
259+ model = ppsci .arch .MLP (** cfg .MODEL )
260+
261+ # initialize solver
262+ solver = ppsci .solver .Solver (
263+ model ,
264+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
265+ )
266+ # export model
267+ from paddle .static import InputSpec
268+
269+ input_spec = [
270+ {
271+ key : InputSpec ([None , 1 ], "float32" , name = key )
272+ for key in cfg .MODEL .input_keys
273+ },
274+ ]
275+ solver .export (input_spec , cfg .INFER .export_path )
276+
277+
278+ def inference (cfg : DictConfig ):
279+ from deploy .python_infer import pinn_predictor
280+
281+ predictor = pinn_predictor .PINNPredictor (cfg )
282+
283+ # set geometry
284+ geom = {"timedomain" : ppsci .geometry .TimeDomain (* cfg .BOUNDS )}
285+
286+ input_data = geom ["timedomain" ].uniform_points (cfg .EVAL .npoint_eval )
287+ input_dict = {"x" : input_data }
288+
289+ output_dict = predictor .predict (
290+ {key : input_dict [key ] for key in cfg .MODEL .input_keys }, cfg .INFER .batch_size
291+ )
292+
293+ # mapping data to cfg.INFER.output_keys
294+ output_dict = {
295+ store_key : output_dict [infer_key ]
296+ for store_key , infer_key in zip (cfg .MODEL .output_keys , output_dict .keys ())
297+ }
298+
299+ def u_solution_func (in_ ) -> np .ndarray :
300+ if isinstance (in_ ["x" ], paddle .Tensor ):
301+ return paddle .exp (- in_ ["x" ]) * paddle .cosh (in_ ["x" ])
302+ return np .exp (- in_ ["x" ]) * np .cosh (in_ ["x" ])
303+
304+ label_data = u_solution_func ({"x" : input_data })
305+ output_data = output_dict ["u" ]
306+
307+ # save result
308+ plt .plot (input_data , label_data , "-" , label = r"$u(t)$" )
309+ plt .plot (input_data , output_data , "o" , label = r"$\hat{u}(t)$" , markersize = 4.0 )
310+ plt .legend ()
311+ plt .xlabel (r"$t$" )
312+ plt .ylabel (r"$u$" )
313+ plt .title (r"$u-t$" )
314+ plt .savefig ("./Volterra_IDE_pred.png" , dpi = 200 )
315+
316+
257317@hydra .main (version_base = None , config_path = "./conf" , config_name = "volterra_ide.yaml" )
258318def main (cfg : DictConfig ):
259319 if cfg .mode == "train" :
260320 train (cfg )
261321 elif cfg .mode == "eval" :
262322 evaluate (cfg )
323+ elif cfg .mode == "export" :
324+ export (cfg )
325+ elif cfg .mode == "infer" :
326+ inference (cfg )
263327 else :
264- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
328+ raise ValueError (
329+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
330+ )
265331
266332
267333if __name__ == "__main__" :
0 commit comments