Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions deploy/python_infer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing_extensions import Literal

from ppsci.utils import logger
from ppsci.utils import misc

if TYPE_CHECKING:
import onnxruntime
Expand Down Expand Up @@ -99,15 +100,19 @@ def predict(self, input_dict):
def _create_paddle_predictor(
self,
) -> Tuple[paddle_inference.Predictor, paddle_inference.Config]:
if misc.check_flag_enabled("FLAGS_enable_pir_api"):
# NOTE: Using 'json' as suffix instead of 'pdmodel' in PIR mode
self.pdmodel_path = self.pdmodel_path.replace(".pdmodel", ".json", 1)

if not osp.exists(self.pdmodel_path):
raise FileNotFoundError(
f"Given 'pdmodel_path': {self.pdmodel_path} does not exist. "
"Please check if it is correct."
"Please check if cfg.INFER.pdmodel_path is correct."
)
if not osp.exists(self.pdiparams_path):
raise FileNotFoundError(
f"Given 'pdiparams_path': {self.pdiparams_path} does not exist. "
"Please check if it is correct."
"Please check if cfg.INFER.pdiparams_path is correct."
)

config = paddle_inference.Config(self.pdmodel_path, self.pdiparams_path)
Expand Down
9 changes: 8 additions & 1 deletion ppsci/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,11 +913,18 @@ def export(
raise e
logger.message(
f"Inference model has been exported to: {export_path}, including "
"*.pdmodel, *.pdiparams and *.pdiparams.info files."
+ (
"*.json, *.pdiparams files."
if misc.check_flag_enabled("FLAGS_enable_pir_api")
else "*.pdmodel, *.pdiparams and *.pdiparams.info files."
)
)
jit.enable_to_static(False)

if with_onnx:
# TODO: support pir + onnx
if misc.check_flag_enabled("FLAGS_enable_pir_api"):
raise ValueError("paddle2onnx does not support PIR mode yet.")
if not importlib.util.find_spec("paddle2onnx"):
raise ModuleNotFoundError(
"Please install paddle2onnx with `pip install paddle2onnx`"
Expand Down
16 changes: 16 additions & 0 deletions ppsci/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"run_on_eval_mode",
"run_at_rank0",
"plot_curve",
"check_flag_enabled",
]


Expand Down Expand Up @@ -631,3 +632,18 @@ def plot_curve(
plt.savefig(os.path.join(output_dir, f"{xlabel}-{ylabel}_curve.jpg"), dpi=200)
plt.clf()
plt.close()


def check_flag_enabled(flag_name: str) -> bool:
"""Check whether the flag is enabled.

Args:
flag_name(str): Flag name to be checked whether enabled or disabled.

Returns:
bool: Whether given flag name is enabled in environment.
"""
value = os.getenv(flag_name, False)
if isinstance(value, str):
return value.lower() in ["true", "1"]
return False