# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import time
import warnings
from typing import List, Optional, Sequence

import mmcv
import torch.distributed as dist
from mmcv.runner import RUNNERS, EpochBasedRunner, IterBasedRunner, IterLoader, get_dist_info
from mmcv.runner.utils import get_host_info
from torch.utils.data.dataloader import DataLoader

from ote_sdk.utils.argument_checks import check_input_parameters_type


@RUNNERS.register_module()
class EpochRunnerWithCancel(EpochBasedRunner):
    """
    Simple modification to EpochBasedRunner to allow cancelling the training during an epoch.

    A stopping hook should set the ``runner.should_stop`` flag to True if stopping is
    required (see the illustrative hook sketch after this class).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.should_stop = False
        _, world_size = get_dist_info()
        self.distributed = world_size > 1

    def stop(self) -> bool:
        """Return True if the training loop should be interrupted.

        Supports distributed training by broadcasting ``should_stop`` from rank 0
        to all other ranks, so every process leaves the loop together.

        :return: whether training should be cancelled
        """
        broadcast_obj = [False]
        if self.rank == 0 and self.should_stop:
            broadcast_obj = [True]

        if self.distributed:
            dist.broadcast_object_list(broadcast_obj, src=0)
        if broadcast_obj[0]:
            self._max_epochs = self.epoch
        return broadcast_obj[0]

    @check_input_parameters_type()
    def train(self, data_loader: DataLoader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        # TODO: uncomment the line below, or resolve the root cause of the deadlock issue,
        #       if multi-GPU training needs to be supported.
        # time.sleep(2)  # Prevent possible multi-gpu deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            if self.stop():
                break
            self._iter += 1
        self.call_hook('after_train_epoch')
        # Check again so a cancellation raised by 'after_train_epoch' hooks also
        # takes effect (stop() caps _max_epochs when the flag is set).
        self.stop()
        self._epoch += 1
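

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal stopping
# hook that requests cancellation by setting ``runner.should_stop``, which
# EpochRunnerWithCancel polls through ``stop()``. The hook name and the
# time-budget behaviour are hypothetical examples, not the cancellation hook
# actually used by this code base.
# ---------------------------------------------------------------------------
from mmcv.runner import HOOKS, Hook  # mmcv is already a dependency of this module


@HOOKS.register_module()
class ExampleTimeBudgetCancelHook(Hook):
    """Hypothetical hook: cancel training once a wall-clock budget is exceeded."""

    def __init__(self, max_seconds: float = 3600.0):
        self.max_seconds = max_seconds
        self._start_time = None

    def before_run(self, runner):
        # Record when training starts.
        self._start_time = time.time()

    def after_train_iter(self, runner):
        # Ask the runner to stop; it breaks out of its loop on the next
        # ``stop()`` check (rank 0 broadcasts the flag in distributed mode).
        if time.time() - self._start_time > self.max_seconds:
            runner.should_stop = True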


@RUNNERS.register_module()
class IterBasedRunnerWithCancel(IterBasedRunner):
    """
    Simple modification to IterBasedRunner to allow cancelling the training.

    The cancel training hook should set the ``runner.should_stop`` flag to True if stopping
    is required (see the usage sketch at the end of this module).

    # TODO: Implement cancelling of training via a keyboard interrupt signal, instead of should_stop
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.should_stop = False

    @check_input_parameters_type()
    def main_loop(self, workflow: List[tuple], iter_loaders: Sequence[IterLoader], **kwargs):
        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run a workflow')
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)
                    if self.should_stop:
                        return

    @check_input_parameters_type()
    def run(self, data_loaders: Sequence[DataLoader], workflow: List[tuple],
            max_iters: Optional[int] = None, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn('setting max_iters in run is deprecated, '
                          'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, 'max_iters must be specified during instantiation'

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d iters', workflow, self._max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        self.should_stop = False
        self.main_loop(workflow, iter_loaders, **kwargs)
        self.should_stop = False

        time.sleep(1)  # wait for some hooks, like loggers, to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')
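

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, kept in comments so importing this module has no
# side effects): the runners above are selected through the mmcv runner
# registry, typically from a config entry such as
# ``runner = dict(type='EpochRunnerWithCancel', max_epochs=12)``. The names
# ``model``, ``optimizer``, ``logger`` and ``train_loader`` below are
# placeholders for objects built elsewhere.
#
#     from mmcv.runner import build_runner
#
#     runner = build_runner(
#         dict(type='IterBasedRunnerWithCancel', max_iters=1000),
#         default_args=dict(model=model, optimizer=optimizer,
#                           work_dir='./work_dir', logger=logger))
#     # A single 'train' stage; a stopping hook cancels it via runner.should_stop.
#     runner.run([train_loader], workflow=[('train', 1)])
# ---------------------------------------------------------------------------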