@@ -11,7 +11,7 @@
 import torch
 from mmcv.utils import ConfigDict
 from segmentation_tasks.apis.segmentation.config_utils import remove_from_config
-from segmentation_tasks.apis.segmentation.ote_utils import TrainingProgressCallback
+from segmentation_tasks.apis.segmentation.ote_utils import TrainingProgressCallback, InferenceProgressCallback
 from segmentation_tasks.extension.utils.hooks import OTELoggerHook
 from mpa import MPAConstants
 from mpa_tasks.apis import BaseTask, TrainType
@@ -22,6 +22,7 @@
 from ote_sdk.configuration.helper.utils import ids_to_strings
 from ote_sdk.entities.datasets import DatasetEntity
 from ote_sdk.entities.inference_parameters import InferenceParameters
+from ote_sdk.entities.inference_parameters import default_progress_callback as default_infer_progress_callback
 from ote_sdk.entities.label import Domain
 from ote_sdk.entities.metrics import (CurveMetric, InfoMetric, LineChartInfo,
                                       MetricsGroup, Performance, ScoreMetric,
@@ -48,8 +49,6 @@
                                       create_annotation_from_segmentation_map,
                                       create_hard_prediction_from_soft_prediction)

-# from mmdet.apis import export_model
-

 logger = get_logger()
@@ -70,12 +69,14 @@ def infer(self,
         logger.info('infer()')

         if inference_parameters is not None:
-            # update_progress_callback = inference_parameters.update_progress
+            update_progress_callback = inference_parameters.update_progress
             is_evaluation = inference_parameters.is_evaluation
         else:
-            # update_progress_callback = default_infer_progress_callback
+            update_progress_callback = default_infer_progress_callback
             is_evaluation = False

+        self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)
+
         stage_module = 'SegInferrer'
         self._data_cfg = self._init_test_data_cfg(dataset)
         self._label_dictionary = dict(enumerate(self._labels, 1))
@@ -187,8 +188,10 @@ def _init_test_data_cfg(self, dataset: DatasetEntity):
         data_cfg = ConfigDict(
             data=ConfigDict(
                 train=ConfigDict(
-                    ote_dataset=None,
-                    labels=self._labels,
+                    dataset=ConfigDict(
+                        ote_dataset=None,
+                        labels=self._labels,
+                    )
                 ),
                 test=ConfigDict(
                     ote_dataset=dataset,
@@ -311,7 +314,7 @@ def cancel_training(self):
         will therefore take some time.
         """
         logger.info("Cancel training requested.")
-        # self._should_stop = True
+        self._should_stop = True
         # stop_training_filepath = os.path.join(self._training_work_dir, '.stop_training')
         # open(stop_training_filepath, 'a').close()
         if self.cancel_interface is not None:
@@ -325,6 +328,14 @@ def train(self,
              output_model: ModelEntity,
              train_parameters: Optional[TrainParameters] = None):
         logger.info('train()')
+        # Check for stop signal between pre-eval and training.
+        # If training is cancelled at this point,
+        if self._should_stop:
+            logger.info('Training cancelled.')
+            self._should_stop = False
+            self._is_training = False
+            return
+
         # Set OTE LoggerHook & Time Monitor
         if train_parameters is not None:
             update_progress_callback = train_parameters.update_progress
@@ -336,8 +347,17 @@ def train(self,
         # learning_curves = defaultdict(OTELoggerHook.Curve)
         stage_module = 'SegTrainer'
         self._data_cfg = self._init_train_data_cfg(dataset)
+        self._is_training = True
         results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=train_parameters)

+        # Check for stop signal when training has stopped.
+        # If should_stop is true, training was cancelled and no new
+        if self._should_stop:
+            logger.info('Training cancelled.')
+            self._should_stop = False
+            self._is_training = False
+            return
+
         # get output model
         model_ckpt = results.get('final_ckpt')
         if model_ckpt is None:
@@ -358,6 +378,7 @@ def train(self,
         self.save_model(output_model)
         output_model.performance = performance
         # output_model.model_status = ModelStatus.SUCCESS
+        self._is_training = False
         logger.info('train done.')

     def _init_train_data_cfg(self, dataset: DatasetEntity):
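
For orientation (not part of the diff itself): the changes above wire up a cooperative cancellation handshake. cancel_training() only sets self._should_stop; train() checks that flag before launching the training stage and again after _run_task() returns, clearing the flag and self._is_training when it bails out without producing a new model. Below is a minimal, self-contained sketch of the same pattern; ToyTrainTask, its methods, and the timings are illustrative stand-ins, not the repository's actual BaseTask API.

import threading
import time


class ToyTrainTask:
    """Illustrative stand-in for a task with cooperative cancellation."""

    def __init__(self):
        self._should_stop = False   # set by cancel_training(), consumed by train()
        self._is_training = False

    def cancel_training(self):
        # Only requests cancellation; train() decides when it is safe to stop.
        self._should_stop = True

    def train(self, num_steps=5):
        # Pre-flight check, mirroring the check added before the training stage.
        if self._should_stop:
            self._should_stop = False
            return None
        self._is_training = True
        for _ in range(num_steps):
            if self._should_stop:        # cancelled mid-run: return no new model
                self._should_stop = False
                self._is_training = False
                return None
            time.sleep(0.1)              # stand-in for one training step
        self._is_training = False
        return 'final_ckpt'


task = ToyTrainTask()
threading.Timer(0.25, task.cancel_training).start()  # request cancel shortly after start
print(task.train())  # prints None, because the run was cancelled between steps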