@@ -518,6 +518,112 @@ def evaluate(cfg: DictConfig):
518518 plt .savefig (osp .join (cfg .output_dir , f"shock_wave(Ma_{ cfg .MA :.3f} ).png" ))
519519
520520
521+ def export (cfg : DictConfig ):
522+ from paddle .static import InputSpec
523+
524+ # set models
525+ model = ppsci .arch .MLP (** cfg .MODEL )
526+ solver = ppsci .solver .Solver (
527+ model ,
528+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
529+ )
530+
531+ # export models
532+ input_spec = [
533+ {key : InputSpec ([None , 1 ], "float32" , name = key ) for key in model .input_keys },
534+ ]
535+ solver .export (input_spec , cfg .INFER .export_path )
536+
537+
538+ def inference (cfg : DictConfig ):
539+ from deploy .python_infer import pinn_predictor
540+
541+ # set model predictor
542+ predictor = pinn_predictor .PINNPredictor (cfg )
543+
544+ # visualize prediction
545+ t = np .linspace (cfg .T , cfg .T , 1 , dtype = np .float32 )
546+ x = np .linspace (0.0 , cfg .Lx , cfg .Nd , dtype = np .float32 )
547+ y = np .linspace (0.0 , cfg .Ly , cfg .Nd , dtype = np .float32 )
548+ _ , x_grid , y_grid = np .meshgrid (t , x , y )
549+
550+ x_test = misc .cartesian_product (t , x , y )
551+ x_test_dict = misc .convert_to_dict (
552+ x_test ,
553+ cfg .MODEL .input_keys ,
554+ )
555+ output_dict = predictor .predict (
556+ x_test_dict ,
557+ cfg .INFER .batch_size ,
558+ )
559+
560+ # mapping data to cfg.MODEL.output_keys
561+ output_dict = {
562+ store_key : output_dict [infer_key ]
563+ for store_key , infer_key in zip (cfg .MODEL .output_keys , output_dict .keys ())
564+ }
565+
566+ u , v , p , rho = (
567+ output_dict ["u" ],
568+ output_dict ["v" ],
569+ output_dict ["p" ],
570+ output_dict ["rho" ],
571+ )
572+
573+ zero_mask = (
574+ (x_test [:, 1 ] - cfg .rx ) ** 2 + (x_test [:, 2 ] - cfg .ry ) ** 2
575+ ) < cfg .rd ** 2
576+ u [zero_mask ] = 0
577+ v [zero_mask ] = 0
578+ p [zero_mask ] = 0
579+ rho [zero_mask ] = 0
580+
581+ u = u .reshape (cfg .Nd , cfg .Nd )
582+ v = v .reshape (cfg .Nd , cfg .Nd )
583+ p = p .reshape (cfg .Nd , cfg .Nd )
584+ rho = rho .reshape (cfg .Nd , cfg .Nd )
585+
586+ fig , ax = plt .subplots (2 , 2 , sharex = True , sharey = True , figsize = (15 , 15 ))
587+
588+ plt .subplot (2 , 2 , 1 )
589+ plt .contourf (x_grid [:, 0 , :], y_grid [:, 0 , :], u * 241.315 , 60 )
590+ plt .title ("U m/s" )
591+ plt .xlabel ("x" )
592+ plt .ylabel ("y" )
593+ axe = plt .gca ()
594+ axe .set_aspect (1 )
595+ plt .colorbar ()
596+
597+ plt .subplot (2 , 2 , 2 )
598+ plt .contourf (x_grid [:, 0 , :], y_grid [:, 0 , :], v * 241.315 , 60 )
599+ plt .title ("V m/s" )
600+ plt .xlabel ("x" )
601+ plt .ylabel ("y" )
602+ axe = plt .gca ()
603+ axe .set_aspect (1 )
604+ plt .colorbar ()
605+
606+ plt .subplot (2 , 2 , 3 )
607+ plt .contourf (x_grid [:, 0 , :], y_grid [:, 0 , :], p * 33775 , 60 )
608+ plt .title ("P Pa" )
609+ plt .xlabel ("x" )
610+ plt .ylabel ("y" )
611+ axe = plt .gca ()
612+ axe .set_aspect (1 )
613+ plt .colorbar ()
614+
615+ plt .subplot (2 , 2 , 4 )
616+ plt .contourf (x_grid [:, 0 , :], y_grid [:, 0 , :], rho * 0.58 , 60 )
617+ plt .title ("Rho kg/m^3" )
618+ plt .xlabel ("x" )
619+ plt .ylabel ("y" )
620+ axe = plt .gca ()
621+ axe .set_aspect (1 )
622+ plt .colorbar ()
623+
624+ plt .savefig (osp .join (cfg .output_dir , f"shock_wave(Ma_{ cfg .MA :.3f} ).png" ))
625+
626+
521627@hydra .main (
522628 version_base = None , config_path = "./conf" , config_name = "shock_wave_Ma2.0.yaml"
523629)
@@ -526,8 +632,14 @@ def main(cfg: DictConfig):
526632 train (cfg )
527633 elif cfg .mode == "eval" :
528634 evaluate (cfg )
635+ elif cfg .mode == "export" :
636+ export (cfg )
637+ elif cfg .mode == "infer" :
638+ inference (cfg )
529639 else :
530- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
640+ raise ValueError (
641+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
642+ )
531643
532644
533645if __name__ == "__main__" :
0 commit comments