diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py index 2653bf643d74b..f71e5df787aa2 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py @@ -65,7 +65,7 @@ def add_skip_case( self.skip_cases.append((teller, reason, note)) @abc.abstractmethod - def is_program_validity(self, program_config: ProgramConfig) -> bool: + def is_program_valid(self, program_config: ProgramConfig) -> bool: raise NotImplementedError def run_test_config(self, model, params, prog_config, pred_config, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d.py index 4045350dafcbf..1456e9caba690 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d.py @@ -21,7 +21,7 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): - def is_program_validity(self, program_config: ProgramConfig) -> bool: + def is_program_valid(self, program_config: ProgramConfig) -> bool: # TODO: This is just the example to remove the wrong attrs. inputs = program_config.inputs weights = program_config.weights @@ -110,10 +110,6 @@ def generate_weight1(attrs: List[Dict[str, Any]]): }, outputs=["relu_output_data"]) - # if program is invalid, we should skip that cases. - if not self.is_program_validity(program_config): - continue - yield program_config def sample_predictor_configs( diff --git a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py index 1f073a99fd3c2..24eff6cbe194d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py @@ -164,6 +164,10 @@ def run_test(self, quant=False): status = True for prog_config in self.sample_program_configs(): + # if program is invalid, we should skip that cases. + if not self.is_program_valid(prog_config): + continue + model, params = create_fake_model(prog_config) if quant: model, params = create_quant_model(model, params) @@ -208,6 +212,7 @@ def run_test(self, quant=False): inference_config_str(pred_config)) else: raise NotImplementedError + break try: results.append(