Skip to content

Commit 80249a4

Browse files
author
pfinashx
committed
Applying comments v2
1 parent d13d518 commit 80249a4

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

external/anomaly/tests/conftest.py

+10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ def ote_templates_root_dir_fx():
4141
logger.debug(f'overloaded ote_templates_root_dir_fx: return {root}')
4242
return root
4343

44+
@pytest.fixture(scope='session')
45+
def ote_reference_root_dir_fx():
46+
import os.path as osp
47+
import logging
48+
logger = logging.getLogger(__name__)
49+
root = osp.dirname(osp.dirname(osp.realpath(__file__)))
50+
root = f'{root}/tests/reference/'
51+
logger.debug(f'overloaded ote_reference_root_dir_fx: return {root}')
52+
return root
53+
4454
# pytest magic
4555
def pytest_generate_tests(metafunc):
4656
ote_pytest_generate_tests_insertion(metafunc)

external/anomaly/tests/test_ote_training.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
OTETestNNCFAction,
5959
OTETestNNCFEvaluationAction,
6060
OTETestNNCFExportAction,
61-
OTETestNNCFExportEvaluationAction)
61+
OTETestNNCFExportEvaluationAction,
62+
OTETestNNCFGraphAction)
6263

6364
logger = get_logger(__name__)
6465

@@ -119,6 +120,7 @@ def get_anomaly_test_action_classes() -> List[Type[BaseOTETestAction]]:
119120
OTETestNNCFEvaluationAction,
120121
OTETestNNCFExportAction,
121122
OTETestNNCFExportEvaluationAction,
123+
OTETestNNCFGraphAction,
122124
]
123125

124126

@@ -294,7 +296,8 @@ def get_list_of_tests(cls, usecase: Optional[str] = None):
294296

295297
@pytest.fixture
296298
def params_factories_for_test_actions_fx(self, current_test_parameters_fx,
297-
dataset_definitions_fx, template_paths_fx) -> Dict[str,Callable[[], Dict]]:
299+
dataset_definitions_fx,ote_current_reference_dir_fx,
300+
template_paths_fx) -> Dict[str,Callable[[], Dict]]:
298301
logger.debug('params_factories_for_test_actions_fx: begin')
299302

300303
test_parameters = deepcopy(current_test_parameters_fx)
@@ -323,8 +326,36 @@ def _training_params_factory() -> Dict:
323326
'patience': patience,
324327
'batch_size': batch_size,
325328
}
329+
330+
def _nncf_graph_params_factory() -> Dict:
331+
if dataset_definitions is None:
332+
pytest.skip('The parameter "--dataset-definitions" is not set')
333+
334+
model_name = test_parameters['model_name']
335+
dataset_name = test_parameters['dataset_name']
336+
337+
dataset_params = _get_dataset_params_from_dataset_definitions(dataset_definitions, dataset_name)
338+
339+
if model_name not in template_paths:
340+
raise ValueError(f'Model {model_name} is absent in template_paths, '
341+
f'template_paths.keys={list(template_paths.keys())}')
342+
template_path = make_path_be_abs(template_paths[model_name], template_paths[ROOT_PATH_KEY])
343+
344+
logger.debug('training params factory: Before creating dataset and labels_schema')
345+
dataset, labels_schema = _create_anomaly_classification_dataset_and_labels_schema(dataset_params, dataset_name)
346+
logger.debug('training params factory: After creating dataset and labels_schema')
347+
348+
return {
349+
'dataset': dataset,
350+
'labels_schema': labels_schema,
351+
'template_path': template_path,
352+
'reference_dir': ote_current_reference_dir_fx,
353+
'fn_get_compressed_model': None #NNCF not yet implemented in Anomaly
354+
}
355+
326356
params_factories_for_test_actions = {
327-
'training': _training_params_factory
357+
'training': _training_params_factory,
358+
'nncf_graph': _nncf_graph_params_factory,
328359
}
329360
logger.debug('params_factories_for_test_actions_fx: end')
330361
return params_factories_for_test_actions

0 commit comments

Comments
 (0)