Skip to content

Commit 13ef9a7

Browse files
add validator check
1 parent 941228a commit 13ef9a7

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

deploy/python_infer/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def _create_paddle_predictor(
101101
self,
102102
) -> Tuple[paddle_inference.Predictor, paddle_inference.Config]:
103103
if misc.check_flag_enabled("FLAGS_enable_pir_api"):
104-
# PIR mode
105-
self.pdmodel_path = self.pdmodel_path.replace(".pdmodel", ".json")
104+
# NOTE: Using 'json' as suffix instead of 'pdmodel' in PIR mode
105+
self.pdmodel_path = self.pdmodel_path.replace(".pdmodel", ".json", 1)
106106

107107
if not osp.exists(self.pdmodel_path):
108108
raise FileNotFoundError(
@@ -114,8 +114,8 @@ def _create_paddle_predictor(
114114
f"Given 'pdiparams_path': {self.pdiparams_path} does not exist. "
115115
"Please check if cfg.INFER.pdiparams_path is correct."
116116
)
117-
config = paddle_inference.Config(self.pdmodel_path, self.pdiparams_path)
118117

118+
config = paddle_inference.Config(self.pdmodel_path, self.pdiparams_path)
119119
if self.device == "gpu":
120120
config.enable_use_gpu(self.gpu_mem, self.gpu_id)
121121
if self.engine == "tensorrt":

ppsci/solver/solver.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,11 @@ def __init__(
277277
)
278278
_count[metric_name] = 1
279279
del _count
280+
elif self.eval_during_train:
281+
raise ValueError(
282+
"eval_during_train is enabled but no validator is provided, "
283+
"please provide a validator to enable evaluation during training."
284+
)
280285

281286
# whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation
282287
if not cfg:
@@ -911,7 +916,6 @@ def export(
911916
jit.save(static_model, export_path, skip_prune_program=skip_prune_program)
912917
except Exception as e:
913918
raise e
914-
915919
logger.message(
916920
f"Inference model has been exported to: {export_path}, including "
917921
+ (
@@ -920,7 +924,6 @@ def export(
920924
else "*.pdmodel, *.pdiparams and *.pdiparams.info files."
921925
)
922926
)
923-
924927
jit.enable_to_static(False)
925928

926929
if with_onnx:

ppsci/utils/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,13 +635,13 @@ def plot_curve(
635635

636636

637637
def check_flag_enabled(flag_name: str) -> bool:
638-
"""Split given array into single channel array at axis -1 in order of given keys.
638+
"""Check whether the flag is enabled.
639639
640640
Args:
641641
flag_name(str): Flag name to be checked whether enabled or disabled.
642642
643643
Returns:
644-
bool: Whether the flag is enabled.
644+
bool: Whether given flag name is enabled in environment.
645645
"""
646646
value = os.getenv(flag_name, False)
647647
if isinstance(value, str):

0 commit comments

Comments
 (0)