Skip to content

Commit 4b58c23

Browse files
author
Songki Choi
authored
Merge pull request #1088 from openvinotoolkit/class-incr-learning-validation
Applying #1084 to develop branch
2 parents 2cf0444 + 4dc5d3a commit 4b58c23

File tree

16 files changed

+1579
-63
lines changed

16 files changed

+1579
-63
lines changed

external/mmsegmentation/tests/test_ote_training.py

+3
Original file line numberDiff line numberDiff line change
@@ -276,5 +276,8 @@ def test(self,
276276
test_parameters,
277277
test_case_fx, data_collector_fx,
278278
cur_test_expected_metrics_callback_fx):
279+
if "18_OCR" in test_parameters["model_name"] \
280+
or "x-mod3_OCR" in test_parameters["model_name"]:
281+
pytest.skip("Known issue CVS-83781")
279282
test_case_fx.run_stage(test_parameters['test_stage'], data_collector_fx,
280283
cur_test_expected_metrics_callback_fx)

external/model-preparation-algorithm/mpa_tasks/apis/classification/task.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from ote_sdk.entities.datasets import DatasetEntity
1717
from ote_sdk.entities.inference_parameters import InferenceParameters, default_progress_callback
18+
from ote_sdk.entities.train_parameters import default_progress_callback as train_default_progress_callback
1819
from ote_sdk.entities.model import ModelEntity, ModelPrecision # ModelStatus
1920
from ote_sdk.entities.resultset import ResultSetEntity
2021
from mmcv.utils import ConfigDict
@@ -217,6 +218,7 @@ def cancel_training(self):
217218
The stopping mechanism allows stopping after each iteration, but validation will still be carried out. Stopping
218219
will therefore take some time.
219220
"""
221+
self._should_stop = True
220222
logger.info("Cancel training requested.")
221223
if self.cancel_interface is not None:
222224
self.cancel_interface.cancel()
@@ -229,17 +231,34 @@ def train(self,
229231
output_model: ModelEntity,
230232
train_parameters: Optional[TrainParameters] = None):
231233
logger.info('train()')
234+
# Check for stop signal between pre-eval and training.
235+
# If training is cancelled at this point, return without starting training.
236+
if self._should_stop:
237+
logger.info('Training cancelled.')
238+
self._should_stop = False
239+
self._is_training = False
240+
return
241+
232242
# Set OTE LoggerHook & Time Monitor
233-
update_progress_callback = default_progress_callback
243+
update_progress_callback = train_default_progress_callback
234244
if train_parameters is not None:
235245
update_progress_callback = train_parameters.update_progress
236246
self._time_monitor = TrainingProgressCallback(update_progress_callback)
237247
self._learning_curves = defaultdict(OTELoggerHook.Curve)
238248

239249
stage_module = 'ClsTrainer'
240250
self._data_cfg = self._init_train_data_cfg(dataset)
251+
self._is_training = True
241252
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=train_parameters)
242253

254+
# Check for stop signal after training has returned.
255+
# If training was cancelled, return without producing an output model.
256+
if self._should_stop:
257+
logger.info('Training cancelled.')
258+
self._should_stop = False
259+
self._is_training = False
260+
return
261+
243262
# get output model
244263
model_ckpt = results.get('final_ckpt')
245264
if model_ckpt is None:
@@ -257,6 +276,7 @@ def train(self,
257276
dashboard_metrics=training_metrics)
258277
logger.info(f'Final model performance: {str(performance)}')
259278
output_model.performance = performance
279+
self._is_training = False
260280
logger.info('train done.')
261281

262282
def _init_train_data_cfg(self, dataset: DatasetEntity):

external/model-preparation-algorithm/mpa_tasks/apis/detection/task.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from mmcv.utils import ConfigDict
1313
from detection_tasks.apis.detection.config_utils import remove_from_config
14-
from detection_tasks.apis.detection.ote_utils import TrainingProgressCallback
14+
from detection_tasks.apis.detection.ote_utils import TrainingProgressCallback, InferenceProgressCallback
1515
from detection_tasks.extension.utils.hooks import OTELoggerHook
1616
from mpa_tasks.apis import BaseTask, TrainType
1717
from mpa_tasks.apis.detection import DetectionConfig
@@ -67,6 +67,11 @@ def infer(self,
6767
) -> DatasetEntity:
6868
logger.info('infer()')
6969

70+
update_progress_callback = default_progress_callback
71+
if inference_parameters is not None:
72+
update_progress_callback = inference_parameters.update_progress
73+
74+
self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)
7075
# If confidence threshold is adaptive then up-to-date value should be stored in the model
7176
# and should not be changed during inference. Otherwise user-specified value should be taken.
7277
if not self._hyperparams.postprocessing.result_based_confidence_threshold:
@@ -75,7 +80,7 @@ def infer(self,
7580

7681
stage_module = 'DetectionInferrer'
7782
self._data_cfg = self._init_test_data_cfg(dataset)
78-
results = self._run_task(stage_module, mode='train', dataset=dataset)
83+
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=inference_parameters)
7984
# TODO: InferenceProgressCallback register
8085
logger.debug(f'result of run_task {stage_module} module = {results}')
8186
output = results['outputs']
@@ -310,7 +315,7 @@ def cancel_training(self):
310315
will therefore take some time.
311316
"""
312317
logger.info("Cancel training requested.")
313-
# self._should_stop = True
318+
self._should_stop = True
314319
# stop_training_filepath = os.path.join(self._training_work_dir, '.stop_training')
315320
# open(stop_training_filepath, 'a').close()
316321
if self.cancel_interface is not None:
@@ -324,6 +329,14 @@ def train(self,
324329
output_model: ModelEntity,
325330
train_parameters: Optional[TrainParameters] = None):
326331
logger.info('train()')
332+
# Check for stop signal before training starts.
333+
# If should_stop is true, training was cancelled, so return without training.
334+
if self._should_stop:
335+
logger.info('Training cancelled.')
336+
self._should_stop = False
337+
self._is_training = False
338+
return
339+
327340
# Set OTE LoggerHook & Time Monitor
328341
update_progress_callback = default_progress_callback
329342
if train_parameters is not None:
@@ -333,8 +346,15 @@ def train(self,
333346

334347
stage_module = 'DetectionTrainer'
335348
self._data_cfg = self._init_train_data_cfg(dataset)
349+
self._is_training = True
336350
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=train_parameters)
337-
# logger.info(f'result of run_task {stage_module} module = {results}')
351+
352+
# Check for stop signal when training has stopped. If should_stop is true, training was cancelled, so return without producing an output model.
353+
if self._should_stop:
354+
logger.info('Training cancelled.')
355+
self._should_stop = False
356+
self._is_training = False
357+
return
338358

339359
# get output model
340360
model_ckpt = results.get('final_ckpt')
@@ -389,6 +409,7 @@ def train(self,
389409
self.save_model(output_model)
390410
output_model.performance = performance
391411
# output_model.model_status = ModelStatus.SUCCESS
412+
self._is_training = False
392413
logger.info('train done.')
393414

394415
def _init_train_data_cfg(self, dataset: DatasetEntity):

external/model-preparation-algorithm/mpa_tasks/apis/segmentation/task.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from mmcv.utils import ConfigDict
1313
from segmentation_tasks.apis.segmentation.config_utils import remove_from_config
14-
from segmentation_tasks.apis.segmentation.ote_utils import TrainingProgressCallback
14+
from segmentation_tasks.apis.segmentation.ote_utils import TrainingProgressCallback, InferenceProgressCallback
1515
from segmentation_tasks.extension.utils.hooks import OTELoggerHook
1616
from mpa import MPAConstants
1717
from mpa_tasks.apis import BaseTask, TrainType
@@ -22,6 +22,7 @@
2222
from ote_sdk.configuration.helper.utils import ids_to_strings
2323
from ote_sdk.entities.datasets import DatasetEntity
2424
from ote_sdk.entities.inference_parameters import InferenceParameters
25+
from ote_sdk.entities.inference_parameters import default_progress_callback as default_infer_progress_callback
2526
from ote_sdk.entities.label import Domain
2627
from ote_sdk.entities.metrics import (CurveMetric, InfoMetric, LineChartInfo,
2728
MetricsGroup, Performance, ScoreMetric,
@@ -48,8 +49,6 @@
4849
create_annotation_from_segmentation_map,
4950
create_hard_prediction_from_soft_prediction)
5051

51-
# from mmdet.apis import export_model
52-
5352

5453
logger = get_logger()
5554

@@ -70,12 +69,14 @@ def infer(self,
7069
logger.info('infer()')
7170

7271
if inference_parameters is not None:
73-
# update_progress_callback = inference_parameters.update_progress
72+
update_progress_callback = inference_parameters.update_progress
7473
is_evaluation = inference_parameters.is_evaluation
7574
else:
76-
# update_progress_callback = default_infer_progress_callback
75+
update_progress_callback = default_infer_progress_callback
7776
is_evaluation = False
7877

78+
self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)
79+
7980
stage_module = 'SegInferrer'
8081
self._data_cfg = self._init_test_data_cfg(dataset)
8182
self._label_dictionary = dict(enumerate(self._labels, 1))
@@ -187,8 +188,10 @@ def _init_test_data_cfg(self, dataset: DatasetEntity):
187188
data_cfg = ConfigDict(
188189
data=ConfigDict(
189190
train=ConfigDict(
190-
ote_dataset=None,
191-
labels=self._labels,
191+
dataset=ConfigDict(
192+
ote_dataset=None,
193+
labels=self._labels,
194+
)
192195
),
193196
test=ConfigDict(
194197
ote_dataset=dataset,
@@ -311,7 +314,7 @@ def cancel_training(self):
311314
will therefore take some time.
312315
"""
313316
logger.info("Cancel training requested.")
314-
# self._should_stop = True
317+
self._should_stop = True
315318
# stop_training_filepath = os.path.join(self._training_work_dir, '.stop_training')
316319
# open(stop_training_filepath, 'a').close()
317320
if self.cancel_interface is not None:
@@ -325,6 +328,14 @@ def train(self,
325328
output_model: ModelEntity,
326329
train_parameters: Optional[TrainParameters] = None):
327330
logger.info('train()')
331+
# Check for stop signal between pre-eval and training.
332+
# If training is cancelled at this point, return without starting training.
333+
if self._should_stop:
334+
logger.info('Training cancelled.')
335+
self._should_stop = False
336+
self._is_training = False
337+
return
338+
328339
# Set OTE LoggerHook & Time Monitor
329340
if train_parameters is not None:
330341
update_progress_callback = train_parameters.update_progress
@@ -336,8 +347,17 @@ def train(self,
336347
# learning_curves = defaultdict(OTELoggerHook.Curve)
337348
stage_module = 'SegTrainer'
338349
self._data_cfg = self._init_train_data_cfg(dataset)
350+
self._is_training = True
339351
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=train_parameters)
340352

353+
# Check for stop signal when training has stopped.
354+
# If should_stop is true, training was cancelled, so return without producing an output model.
355+
if self._should_stop:
356+
logger.info('Training cancelled.')
357+
self._should_stop = False
358+
self._is_training = False
359+
return
360+
341361
# get output model
342362
model_ckpt = results.get('final_ckpt')
343363
if model_ckpt is None:
@@ -358,6 +378,7 @@ def train(self,
358378
self.save_model(output_model)
359379
output_model.performance = performance
360380
# output_model.model_status = ModelStatus.SUCCESS
381+
self._is_training = False
361382
logger.info('train done.')
362383

363384
def _init_train_data_cfg(self, dataset: DatasetEntity):

external/model-preparation-algorithm/mpa_tasks/apis/task.py

+11-43
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,6 @@
2929
logger = get_logger()
3030

3131

32-
class _MPAUpdateProgressCallbackWrapper(UpdateProgressCallback):
33-
""" UpdateProgressCallback wrapper
34-
just wrapping the callback instance and provides error free representation as 'pretty_text'
35-
"""
36-
37-
def __init__(self, callback, **kwargs):
38-
if not callable(callback):
39-
raise RuntimeError(f'cannot accept a not callable object!! {callback}')
40-
self._callback = callback
41-
super().__init__(**kwargs)
42-
43-
def __repr__(self):
44-
return f"'{__name__}._MPAUpdateProgressCallbackWrapper'"
45-
46-
def __reduce__(self):
47-
return (self.__class__, (id(self),))
48-
49-
def __call__(self, progress: float, score: Optional[float] = None):
50-
self._callback(progress, score)
51-
52-
5332
class BaseTask:
5433
def __init__(self, task_config, task_environment: TaskEnvironment):
5534
self._task_config = task_config
@@ -83,6 +62,8 @@ def __init__(self, task_config, task_environment: TaskEnvironment):
8362
self._mode = None
8463
self._time_monitor = None
8564
self._learning_curves = None
65+
self._is_training = False
66+
self._should_stop = False
8667
self.cancel_interface = None
8768
self.reserved_cancel = False
8869
self.on_hook_initialized = self.OnHookInitialized(self)
@@ -104,30 +85,9 @@ def _run_task(self, stage_module, mode=None, dataset=None, parameters=None, **kw
10485
raise RuntimeError(
10586
"'recipe_cfg' is not initialized yet."
10687
"call prepare() method before calling this method")
107-
# self._stage_module = stage_module
88+
10889
if mode is not None:
10990
self._mode = mode
110-
if parameters is not None:
111-
if isinstance(parameters, TrainParameters):
112-
hook_name = 'TrainProgressUpdateHook'
113-
progress_callback = _MPAUpdateProgressCallbackWrapper(parameters.update_progress)
114-
# TODO: update recipe to do RESUME
115-
if parameters.resume:
116-
pass
117-
elif isinstance(parameters, InferenceParameters):
118-
hook_name = 'InferenceProgressUpdateHook'
119-
progress_callback = _MPAUpdateProgressCallbackWrapper(parameters.update_progress)
120-
else:
121-
hook_name = 'ProgressUpdateHook'
122-
progress_callback = None
123-
logger.info(f'progress callback = {progress_callback}, hook name = {hook_name}')
124-
if progress_callback is not None:
125-
progress_update_hook_cfg = ConfigDict(
126-
type='ProgressUpdateHook',
127-
name=hook_name,
128-
callback=progress_callback
129-
)
130-
update_or_add_custom_hook(self._recipe_cfg, progress_update_hook_cfg)
13191

13292
common_cfg = ConfigDict(dict(output_path=self._output_path))
13393

@@ -152,6 +112,14 @@ def finalize(self):
152112
if os.path.exists(self._output_path):
153113
shutil.rmtree(self._output_path, ignore_errors=False)
154114

115+
def _delete_scratch_space(self):
116+
"""
117+
Remove model checkpoints and mpa logs
118+
"""
119+
120+
if os.path.exists(self._output_path):
121+
shutil.rmtree(self._output_path, ignore_errors=False)
122+
155123
def __del__(self):
156124
self.finalize()
157125

Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[pytest]
2-
python_files = test_*_cls_il.py
2+
python_files = test_ote_*.py
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (C) 2022 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
5+
try:
6+
import os
7+
from e2e import config as config_e2e
8+
9+
config_e2e.repository_name = os.environ.get('TT_REPOSITORY_NAME', 'ote/training_extensions/external/model-preparation-algorithm')
10+
except ImportError:
11+
pass

0 commit comments

Comments
 (0)