diff --git a/examples/nowcastnet/conf/nowcastnet.yaml b/examples/nowcastnet/conf/nowcastnet.yaml index 29ccec60e..1d4bcef7c 100644 --- a/examples/nowcastnet/conf/nowcastnet.yaml +++ b/examples/nowcastnet/conf/nowcastnet.yaml @@ -11,6 +11,7 @@ hydra: - TRAIN.checkpoint_path - TRAIN.pretrained_model_path - EVAL.pretrained_model_path + - INFER.pretrained_model_path - mode - output_dir - log_freq @@ -55,3 +56,6 @@ MODEL: # evaluation settings EVAL: pretrained_model_path: checkpoints/paddle_mrms_model + +INFER: + pretrained_model_path: checkpoints/paddle_mrms_model diff --git a/examples/nowcastnet/nowcastnet.py b/examples/nowcastnet/nowcastnet.py index 6f3fbde79..c9f9f2774 100644 --- a/examples/nowcastnet/nowcastnet.py +++ b/examples/nowcastnet/nowcastnet.py @@ -82,12 +82,40 @@ def evaluate(cfg: DictConfig): solver.visualize(batch_id) +def export(cfg: DictConfig): + if cfg.CASE_TYPE == "large": + model_cfg = cfg.MODEL.large + elif cfg.CASE_TYPE == "normal": + model_cfg = cfg.MODEL.normal + else: + raise ValueError( + f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'" + ) + model = ppsci.arch.NowcastNet(**model_cfg) + + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + + # export model + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, + ] + solver.export(input_spec, cfg.INFER.export_path) + + @hydra.main(version_base=None, config_path="./conf", config_name="nowcastnet.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) else: raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")