diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py
index 4a6a7bd2f0..4340926698 100644
--- a/ppsci/solver/solver.py
+++ b/ppsci/solver/solver.py
@@ -21,6 +21,7 @@
 import os
 import sys
 from os import path as osp
+from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Dict
 from typing import List
@@ -41,7 +42,6 @@
 from paddle import optimizer as optim
 from paddle.distributed import fleet
 from paddle.framework import core
-from paddle.static import InputSpec
 from typing_extensions import Literal
 
 import ppsci
@@ -51,6 +51,9 @@
 from ppsci.utils import misc
 from ppsci.utils import save_load
 
+if TYPE_CHECKING:
+    from paddle.static import InputSpec
+
 
 class Solver:
     """Class for solver.
@@ -729,7 +732,11 @@ def predict(
 
     @misc.run_on_eval_mode
     def export(
-        self, input_spec: List[InputSpec], export_path: str, with_onnx: bool = False
+        self,
+        input_spec: List["InputSpec"],
+        export_path: str,
+        with_onnx: bool = False,
+        skip_prune_program: bool = False,
     ):
         """
         Convert model to static graph model and export to files.
@@ -739,7 +746,9 @@
                 of the model input.
             export_path (str): The path prefix to save model.
             with_onnx (bool, optional): Whether to export model into onnx after
-                paddle inference models are exported.
+                paddle inference models are exported. Defaults to False.
+            skip_prune_program (bool, optional): Whether to skip pruning the program.
+                Pruning may cause unexpected results, e.g. for llm-inference. Defaults to False.
         """
 
         jit.enable_to_static(True)
@@ -760,7 +769,7 @@
         if len(osp.dirname(export_path)):
            os.makedirs(osp.dirname(export_path), exist_ok=True)
         try:
-            jit.save(static_model, export_path)
+            jit.save(static_model, export_path, skip_prune_program=skip_prune_program)
         except Exception as e:
            raise e
         logger.message(
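
For context on the import change above: moving `from paddle.static import InputSpec` under `if TYPE_CHECKING:` means the import is only evaluated by static type checkers, so the annotation must become a quoted forward reference (`List["InputSpec"]`). A minimal sketch of the pattern, with a hypothetical standalone `export` function standing in for the method in the diff:

```python
from typing import TYPE_CHECKING
from typing import List

if TYPE_CHECKING:
    # Evaluated only by type checkers (e.g. mypy/pyright), never at runtime,
    # so importing this module no longer pulls in paddle.static eagerly.
    from paddle.static import InputSpec


def export(input_spec: List["InputSpec"], export_path: str) -> None:
    # The quoted "InputSpec" is a forward reference resolved lazily by the
    # type checker; nothing is looked up when this module loads.
    ...
```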
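
A hedged usage sketch of the new `skip_prune_program` parameter follows. The MLP construction and input spec shapes are illustrative assumptions, not taken from the diff; only the `skip_prune_program` keyword itself is the new surface being exercised.

```python
from paddle.static import InputSpec  # callers may still import this at runtime

import ppsci

# Illustrative model/solver setup; any ppsci model with matching input keys
# would work the same way.
model = ppsci.arch.MLP(("x", "y"), ("u",), 3, 16)
solver = ppsci.solver.Solver(model)

# One InputSpec per model input key, following the existing export() contract.
input_spec = [
    {
        "x": InputSpec([None, 1], "float32", name="x"),
        "y": InputSpec([None, 1], "float32", name="y"),
    },
]

# skip_prune_program=True keeps the full program, which matters when pruning
# would drop parts of the graph inference still needs (the llm-inference
# case the docstring mentions).
solver.export(input_spec, "./inference/model", skip_prune_program=True)
```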