File tree Expand file tree Collapse file tree 3 files changed +31
-3
lines changed
Expand file tree Collapse file tree 3 files changed +31
-3
lines changed Original file line number Diff line number Diff line change 2525from typing_extensions import Literal
2626
2727from ppsci .utils import logger
28+ from ppsci .utils import misc
2829
2930if TYPE_CHECKING :
3031 import onnxruntime
@@ -99,15 +100,19 @@ def predict(self, input_dict):
99100 def _create_paddle_predictor (
100101 self ,
101102 ) -> Tuple [paddle_inference .Predictor , paddle_inference .Config ]:
103+ if misc .check_flag_enabled ("FLAGS_enable_pir_api" ):
104+ # NOTE: Using 'json' as suffix instead of 'pdmodel' in PIR mode
105+ self .pdmodel_path = self .pdmodel_path .replace (".pdmodel" , ".json" , 1 )
106+
102107 if not osp .exists (self .pdmodel_path ):
103108 raise FileNotFoundError (
104109 f"Given 'pdmodel_path': { self .pdmodel_path } does not exist. "
105- "Please check if it is correct."
110+ "Please check if cfg.INFER.pdmodel_path is correct."
106111 )
107112 if not osp .exists (self .pdiparams_path ):
108113 raise FileNotFoundError (
109114 f"Given 'pdiparams_path': { self .pdiparams_path } does not exist. "
110- "Please check if it is correct."
115+ "Please check if cfg.INFER.pdiparams_path is correct."
111116 )
112117
113118 config = paddle_inference .Config (self .pdmodel_path , self .pdiparams_path )
Original file line number Diff line number Diff line change @@ -913,11 +913,18 @@ def export(
913913 raise e
914914 logger .message (
915915 f"Inference model has been exported to: { export_path } , including "
916- "*.pdmodel, *.pdiparams and *.pdiparams.info files."
916+ + (
917+ "*.json, *.pdiparams files."
918+ if misc .check_flag_enabled ("FLAGS_enable_pir_api" )
919+ else "*.pdmodel, *.pdiparams and *.pdiparams.info files."
920+ )
917921 )
918922 jit .enable_to_static (False )
919923
920924 if with_onnx :
925+ # TODO: support pir + onnx
926+ if misc .check_flag_enabled ("FLAGS_enable_pir_api" ):
927+ raise ValueError ("paddle2onnx does not support PIR mode yet." )
921928 if not importlib .util .find_spec ("paddle2onnx" ):
922929 raise ModuleNotFoundError (
923930 "Please install paddle2onnx with `pip install paddle2onnx`"
Original file line number Diff line number Diff line change 5252 "run_on_eval_mode" ,
5353 "run_at_rank0" ,
5454 "plot_curve" ,
55+ "check_flag_enabled" ,
5556]
5657
5758
@@ -631,3 +632,18 @@ def plot_curve(
631632 plt .savefig (os .path .join (output_dir , f"{ xlabel } -{ ylabel } _curve.jpg" ), dpi = 200 )
632633 plt .clf ()
633634 plt .close ()
635+
636+
637+ def check_flag_enabled (flag_name : str ) -> bool :
638+ """Check whether the flag is enabled.
639+
640+ Args:
641+ flag_name(str): Flag name to be checked whether enabled or disabled.
642+
643+ Returns:
644+ bool: Whether given flag name is enabled in environment.
645+ """
646+ value = os .getenv (flag_name , False )
647+ if isinstance (value , str ):
648+ return value .lower () in ["true" , "1" ]
649+ return False
You can’t perform that action at this time.
0 commit comments