-
Notifications
You must be signed in to change notification settings - Fork 6
/
transfer.py
143 lines (123 loc) · 4.97 KB
/
transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import division

import argparse
import importlib
import importlib.util
import os
import os.path as osp
import time

import mmcv
import torch
from mmcv import Config
from mmcv.runner import init_dist
from polyaxon_client import tracking

from openselfsup import __version__
from openselfsup.apis import set_random_seed, train_model
from openselfsup.datasets import build_dataset
from openselfsup.models import build_model
from openselfsup.utils import collect_env, get_root_logger, traverse_replace
def parse_args():
    """Build and parse the command-line arguments for the training script.

    Returns the parsed ``argparse.Namespace``. As a side effect, propagates
    ``--local_rank`` into the ``LOCAL_RANK`` environment variable when it is
    not already set (consumed by torch.distributed launchers).
    """
    parser = argparse.ArgumentParser(description='Train a model')
    add = parser.add_argument
    add('config', help='train config file path')
    add('--work_dir',
        type=str,
        default=None,
        help='the dir to save logs and models')
    add('--resume_from', help='the checkpoint file to resume from')
    add('--pretrained', default=None, help='pretrained model file')
    add('--gpus',
        type=int,
        default=1,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    add('--seed', type=int, default=None, help='random seed')
    add('--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    add('--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    add('--local_rank', type=int, default=0)
    add('--port', type=int, default=29500,
        help='port only works when launcher=="slurm"')
    parsed = parser.parse_args()
    # Only set LOCAL_RANK when absent, matching torch.distributed's contract.
    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed
def main():
    """Entry point: load the config, build model and dataset, save a checkpoint.

    Paths are rebased onto Polyaxon-managed storage: inputs are read from the
    'ceph' data path, and all outputs (logs, checkpoints) go to the
    experiment's output directory.

    NOTE(review): ``train_model`` is imported at module level but never called
    here, so the checkpoint below is written from an untrained (or only
    pretrained-initialized) model -- confirm this is intentional.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Polyaxon storage roots: data under 'ceph', outputs to the run's dir.
    base_path = tracking.get_data_paths()['ceph'] + '/'
    output_dir = tracking.get_outputs_path()

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    cfg.work_dir = output_dir
    if args.resume_from is not None:
        cfg.resume_from = base_path + args.resume_from
    cfg.gpus = args.gpus

    # Disable memcached in the config when the 'mc' client package is absent.
    if importlib.util.find_spec('mc') is None:
        traverse_replace(cfg, 'memcached', False)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        assert cfg.model.type not in \
            ['DeepCluster', 'MOCO', 'SimCLR', 'ODC', 'NPID'], \
            "{} does not support non-dist training.".format(cfg.model.type)
    else:
        distributed = True
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # Rebase the training data list file and image root onto the data path.
    cfg.data.train.data_source.list_file = (
        base_path + cfg.data.train.data_source.list_file)
    cfg.data.train.data_source.root = (
        base_path + cfg.data.train.data_source.root)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'train_{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join(
        '{}: {}'.format(k, v) for k, v in env_info_dict.items())
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info('Distributed training: {}'.format(distributed))
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed

    if args.pretrained is not None:
        assert isinstance(args.pretrained, str)
        cfg.model.pretrained = args.pretrained

    model = build_model(cfg.model)
    # NOTE(review): the dataset is built but never consumed in this script.
    datasets = [build_dataset(cfg.data.train)]
    assert len(cfg.workflow) == 1, "Validation is called by hook."

    if cfg.checkpoint_config is not None:
        # save openselfsup version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            openselfsup_version=__version__, config=cfg.text)

    # FIX: write the checkpoint into the work dir (the run's output dir)
    # instead of the process current working directory, so it is collected
    # alongside the logs.
    torch.save(model.state_dict(),
               osp.join(cfg.work_dir, 'model_best_bacc.pth.tar'),
               _use_new_zipfile_serialization=False)
# Run the training entry point only when executed as a script.
if __name__ == '__main__':
    main()