-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
82 lines (73 loc) · 2.61 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import os
import socket
import sys
from contextlib import closing

# Make packages that sit next to this script (e.g. ``tools``, imported in
# main) importable regardless of the current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Default OpenMP thread count, set at import time -- i.e. before main()
# lazily imports torch / tools.train -- presumably so the OpenMP runtime
# picks it up. A value already present in the environment is kept.
os.environ.setdefault('OMP_NUM_THREADS', '8')
def parse_args(argv=None):
    """Parse the command-line options of this training launcher.

    Args:
        argv: Sequence of argument strings, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which is how ``main`` calls it.

    Returns:
        argparse.Namespace with attributes ``config``, ``work_dir``,
        ``resume_from``, ``no_validate``, ``gpu_ids``, ``seed`` and
        ``deterministic``.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use')
    parser.add_argument('--seed', type=int, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    return parser.parse_args(argv)
def args_to_str(args):
    """Rebuild the option list that is forwarded to ``tools/train.py``.

    Despite the name, this returns a list of argument strings: the config
    path first, then only those options that were actually set. ``gpu_ids``
    is not included; ``main`` communicates the device choice through
    ``CUDA_VISIBLE_DEVICES`` instead.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        list[str] suitable for use as ``sys.argv[1:]``.
    """
    forwarded = [args.config]
    # Value-carrying options are emitted only when a value was supplied.
    valued_options = (
        ('--work-dir', args.work_dir),
        ('--resume-from', args.resume_from),
    )
    for flag, value in valued_options:
        if value is not None:
            forwarded.extend((flag, value))
    if args.no_validate:
        forwarded.append('--no-validate')
    if args.seed is not None:
        # seed is an int after parsing; the forwarded argv must be strings.
        forwarded.extend(('--seed', str(args.seed)))
    if args.deterministic:
        forwarded.append('--deterministic')
    return forwarded
def _resolve_gpu_ids(args):
    """Return the GPU ids to train on.

    Priority: ``--gpu-ids`` from the command line, then the ids listed in an
    existing ``CUDA_VISIBLE_DEVICES``, then ``[0]``. Blank entries in
    ``CUDA_VISIBLE_DEVICES`` (an empty value or a trailing comma) are skipped
    instead of crashing in ``int('')``; if none remain, the ``[0]`` default
    applies.
    """
    if args.gpu_ids is not None:
        return args.gpu_ids
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    gpu_ids = [int(token) for token in visible.split(',') if token.strip()]
    return gpu_ids or [0]


def _find_free_port(start=29500, stop=65536):
    """Return the first TCP port in ``[start, stop)`` bindable on localhost.

    Binding (rather than probing with ``connect``) also rejects ports that are
    bound but not listening. The probe socket is closed before returning, so
    another process could still take the port before the launcher binds it.

    Raises:
        RuntimeError: if no port in the range can be bound.
    """
    for port in range(start, stop):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            try:
                sock.bind(('localhost', port))
            except OSError:
                continue
        return port
    raise RuntimeError(
        'no free TCP port found in range [{}, {})'.format(start, stop))


def main():
    """Launch training on one GPU in-process, or on several GPUs via torch.

    The selected GPU ids are exported through ``CUDA_VISIBLE_DEVICES``. With a
    single GPU, ``tools.train.main()`` is called in this process; otherwise
    ``tools/train.py`` is started through ``torch.distributed.launch`` with one
    process per GPU and a free master port.
    """
    args = parse_args()
    gpu_ids = _resolve_gpu_ids(args)
    # Exported before tools.train / torch are imported below.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in gpu_ids)

    if len(gpu_ids) == 1:
        import tools.train
        # Hand the forwarded options to tools.train via sys.argv
        # (presumably it parses them itself).
        sys.argv = [''] + args_to_str(args)
        tools.train.main()
        return

    from torch.distributed import launch
    port = _find_free_port()
    # Absolute path so the multi-GPU launch works from any working directory,
    # just as the single-GPU `import tools.train` does.
    train_script = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'tools', 'train.py')
    os.environ['training_script'] = train_script
    sys.argv = [
        '',
        '--nproc_per_node={}'.format(len(gpu_ids)),
        '--master_port={}'.format(port),
        train_script,
    ] + args_to_str(args) + ['--launcher', 'pytorch', '--diff_seed']
    launch.main()
# Script entry point: run the launcher only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()