 import os
 import sys
 from os import path as osp
+from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Dict
 from typing import List
 from paddle import optimizer as optim
 from paddle.distributed import fleet
 from paddle.framework import core
-from paddle.static import InputSpec
 from typing_extensions import Literal
 
 import ppsci
 from ppsci.utils import misc
 from ppsci.utils import save_load
 
+if TYPE_CHECKING:
+    from paddle.static import InputSpec
+
 
 class Solver:
     """Class for solver.
@@ -729,7 +732,11 @@ def predict(
 
     @misc.run_on_eval_mode
     def export(
-        self, input_spec: List[InputSpec], export_path: str, with_onnx: bool = False
+        self,
+        input_spec: List["InputSpec"],
+        export_path: str,
+        with_onnx: bool = False,
+        skip_prune_program: bool = False,
     ):
         """
         Convert model to static graph model and export to files.
@@ -739,7 +746,9 @@ def export(
                 of the model input.
             export_path (str): The path prefix to save model.
             with_onnx (bool, optional): Whether to export model into onnx after
-                paddle inference models are exported.
+                paddle inference models are exported. Defaults to False.
+            skip_prune_program (bool, optional): Whether prune program, pruning program
+                may cause unexpectable result, e.g. llm-inference. Defaults to False.
         """
         jit.enable_to_static(True)
@@ -760,7 +769,7 @@ def export(
         if len(osp.dirname(export_path)):
             os.makedirs(osp.dirname(export_path), exist_ok=True)
         try:
-            jit.save(static_model, export_path)
+            jit.save(static_model, export_path, skip_prune_program=skip_prune_program)
         except Exception as e:
             raise e
         logger.message(
0 commit comments