Skip to content

Commit c99aebd

Browse files
committed
added tests to cover get_data_cfg function and StopLossNanTrainingHook after_train_iter method
1 parent 6a7df1c commit c99aebd

File tree

4 files changed

+53
-1
lines changed

4 files changed

+53
-1
lines changed

external/mmdetection/detection_tasks/apis/detection/config_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
353353
return config, model
354354

355355

356+
@check_input_parameters_type()
356357
def get_data_cfg(config: Config, subset: str = 'train') -> Config:
357358
data_cfg = config.data[subset]
358359
while 'dataset' in data_cfg:

external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py

+31
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
cluster_anchors,
1010
config_from_string,
1111
config_to_string,
12+
get_data_cfg,
1213
is_epoch_based_runner,
1314
patch_adaptive_repeat_dataset,
1415
patch_config,
@@ -416,3 +417,33 @@ def test_cluster_anchors_input_params_validation(self):
416417
unexpected_values=unexpected_values,
417418
class_or_function=cluster_anchors,
418419
)
420+
421+
@e2e_pytest_unit
422+
def test_get_data_cfg_input_params_validation(self):
423+
"""
424+
<b>Description:</b>
425+
Check "get_data_cfg" function input parameters validation
426+
427+
<b>Input data:</b>
428+
"get_data_cfg" function input parameters with unexpected type
429+
430+
<b>Expected results:</b>
431+
Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
432+
"get_data_cfg" function
433+
"""
434+
correct_values_dict = {
435+
"config": Config(),
436+
}
437+
unexpected_int = 1
438+
unexpected_values = [
439+
# Unexpected integer is specified as "config" parameter
440+
("config", unexpected_int),
441+
# Unexpected integer is specified as "subset" parameter
442+
("subset", unexpected_int),
443+
]
444+
445+
check_value_error_exception_raised(
446+
correct_parameters=correct_values_dict,
447+
unexpected_values=unexpected_values,
448+
class_or_function=get_data_cfg,
449+
)

external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py

+20
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
FixedMomentumUpdaterHook,
1313
OTELoggerHook,
1414
OTEProgressHook,
15+
StopLossNanTrainingHook,
1516
ReduceLROnPlateauLrUpdaterHook,
1617
)
1718
from mmcv.runner import EpochBasedRunner
@@ -532,3 +533,22 @@ def test_reduce_lr_hook_before_run_params_validation(self):
532533
hook = self.hook()
533534
with pytest.raises(ValueError):
534535
hook.before_run(runner="unexpected string") # type: ignore
536+
537+
538+
class TestStopLossNanTrainingHook:
539+
@e2e_pytest_unit
540+
def test_stop_loss_nan_train_hook_after_train_iter_params_validation(self):
541+
"""
542+
<b>Description:</b>
543+
Check StopLossNanTrainingHook object "after_train_iter" method input parameters validation
544+
545+
<b>Input data:</b>
546+
StopLossNanTrainingHook object, "runner" non-BaseRunner type object
547+
548+
<b>Expected results:</b>
549+
Test passes if ValueError exception is raised when unexpected type object is specified as
550+
input parameter for "after_train_iter" method
551+
"""
552+
hook = StopLossNanTrainingHook()
553+
with pytest.raises(ValueError):
554+
hook.after_train_iter(runner="unexpected string") # type: ignore

external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def model():
3434
)
3535

3636
@e2e_pytest_unit
37-
def test_train_task_input_params_validation(self):
37+
def test_train_task_train_input_params_validation(self):
3838
"""
3939
<b>Description:</b>
4040
Check OTEDetectionTrainingTask object "train" method input parameters validation

0 commit comments

Comments
 (0)