File tree Expand file tree Collapse file tree 3 files changed +10
-7
lines changed
Expand file tree Collapse file tree 3 files changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -101,8 +101,8 @@ def _create_paddle_predictor(
101101 self ,
102102 ) -> Tuple [paddle_inference .Predictor , paddle_inference .Config ]:
103103 if misc .check_flag_enabled ("FLAGS_enable_pir_api" ):
104- # PIR mode
105- self .pdmodel_path = self .pdmodel_path .replace (".pdmodel" , ".json" )
104+ # NOTE: Using 'json' as suffix instead of 'pdmodel' in PIR mode
105+ self .pdmodel_path = self .pdmodel_path .replace (".pdmodel" , ".json" , 1 )
106106
107107 if not osp .exists (self .pdmodel_path ):
108108 raise FileNotFoundError (
@@ -114,8 +114,8 @@ def _create_paddle_predictor(
114114 f"Given 'pdiparams_path': { self .pdiparams_path } does not exist. "
115115 "Please check if cfg.INFER.pdiparams_path is correct."
116116 )
117- config = paddle_inference .Config (self .pdmodel_path , self .pdiparams_path )
118117
118+ config = paddle_inference .Config (self .pdmodel_path , self .pdiparams_path )
119119 if self .device == "gpu" :
120120 config .enable_use_gpu (self .gpu_mem , self .gpu_id )
121121 if self .engine == "tensorrt" :
Original file line number Diff line number Diff line change @@ -277,6 +277,11 @@ def __init__(
277277 )
278278 _count [metric_name ] = 1
279279 del _count
280+ elif self .eval_during_train :
281+ raise ValueError (
282+ "eval_during_train is enabled but no validator is provided, "
283+ "please provide a validator to enable evaluation during training."
284+ )
280285
281286 # whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation
282287 if not cfg :
@@ -911,7 +916,6 @@ def export(
911916 jit .save (static_model , export_path , skip_prune_program = skip_prune_program )
912917 except Exception as e :
913918 raise e
914-
915919 logger .message (
916920 f"Inference model has been exported to: { export_path } , including "
917921 + (
@@ -920,7 +924,6 @@ def export(
920924 else "*.pdmodel, *.pdiparams and *.pdiparams.info files."
921925 )
922926 )
923-
924927 jit .enable_to_static (False )
925928
926929 if with_onnx :
Original file line number Diff line number Diff line change @@ -635,13 +635,13 @@ def plot_curve(
635635
636636
637637def check_flag_enabled (flag_name : str ) -> bool :
638- """Split given array into single channel array at axis -1 in order of given keys .
638+ """Check whether the flag is enabled .
639639
640640 Args:
641641 flag_name(str): Flag name to be checked whether enabled or disabled.
642642
643643 Returns:
644- bool: Whether the flag is enabled.
644+ bool: Whether given flag name is enabled in environment .
645645 """
646646 value = os .getenv (flag_name , False )
647647 if isinstance (value , str ):
You can’t perform that action at this time.
0 commit comments