-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathengine.py
138 lines (110 loc) · 4.64 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2018/8/2 3:23 PM
# @Author : yuchangqian
# @Contact : changqian_yu@163.com
# @File : engine.py
import os
import os.path as osp
import time
import argparse
import torch
import torch.distributed as dist
from utils.logger import get_logger
from utils.pyt_utils import parse_devices, all_reduce_tensor, extant_file
# apex is a hard requirement: Engine.data_parallel wraps models with apex's
# DistributedDataParallel when running distributed.
try:
    from apex.parallel import DistributedDataParallel, SyncBatchNorm
except ImportError as err:
    # Chain the original ImportError so the real cause (e.g. a missing
    # compiled extension) is shown as the direct cause.
    raise ImportError(
        "Please install apex from https://www.github.com/nvidia/apex .") from err

# Module-level logger used by Engine for status and warning messages.
logger = get_logger()
class Engine(object):
    """Set up argument parsing and (optionally distributed) GPU training.

    On construction the engine adds its own options to the argument parser
    (a fresh ``argparse.ArgumentParser`` or the supplied ``custom_parser``),
    parses ``sys.argv``, and, when the ``WORLD_SIZE`` environment variable is
    greater than 1, initializes an NCCL process group via ``env://``. It is
    meant to be used as a context manager so cached CUDA memory is released
    when the ``with`` block exits.

    Attributes:
        parser: The ``argparse.ArgumentParser`` used for parsing.
        args: The parsed argument namespace.
        continue_state_object: Checkpoint path from ``-c/--continue`` (or None).
        distributed: True when ``WORLD_SIZE`` > 1.
        local_rank: Rank of this process on its node (0 when not distributed).
        world_size: Number of processes (1 when not distributed).
        devices: List of device indices ``0..n-1`` used by this run.
    """

    def __init__(self, custom_parser=None):
        """Parse arguments and initialize the distributed backend if needed.

        Args:
            custom_parser: Optional ``argparse.ArgumentParser`` carrying
                project-specific options; the engine's own options are added
                to it before parsing.
        """
        logger.info(
            "PyTorch Version {}".format(torch.__version__))
        self.distributed = False
        # Single-process defaults; overwritten below for a distributed launch
        # so these attributes always exist.
        self.local_rank = 0
        self.world_size = 1

        if custom_parser is None:
            self.parser = argparse.ArgumentParser()
        else:
            assert isinstance(custom_parser, argparse.ArgumentParser)
            self.parser = custom_parser

        self.inject_default_parser()
        self.args = self.parser.parse_args()

        self.continue_state_object = self.args.continue_fpath

        # WORLD_SIZE is read from the environment (exported by distributed
        # launchers such as torchrun); absent means a single process.
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.distributed = world_size > 1

        if self.distributed:
            # torchrun passes the local rank only through the LOCAL_RANK
            # environment variable, not --local_rank; prefer it when present.
            self.local_rank = int(
                os.environ.get('LOCAL_RANK', self.args.local_rank))
            self.world_size = world_size
            torch.cuda.set_device(self.local_rank)
            dist.init_process_group(backend="nccl", init_method='env://')
            self.devices = list(range(self.world_size))
        else:
            self.devices = list(range(self._count_visible_devices()))

    @staticmethod
    def _count_visible_devices():
        """Return how many GPUs a single-process run should use.

        Counts the comma-separated entries of ``CUDA_VISIBLE_DEVICES`` when it
        is set; otherwise falls back to ``torch.cuda.device_count()`` instead
        of failing with ``KeyError``.
        """
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible is None:
            return torch.cuda.device_count()
        return len(visible.split(','))

    def inject_default_parser(self):
        """Add the engine's command-line options to ``self.parser``.

        Options:
            -d/--devices: device string (not read anywhere in this class).
            -c/--continue: checkpoint file to resume from, validated by
                ``extant_file``; stored as ``args.continue_fpath``.
            --local_rank/--local-rank: process rank on the node.
        """
        p = self.parser
        p.add_argument('-d', '--devices', default='',
                       help='set data parallel training')
        p.add_argument('-c', '--continue', type=extant_file,
                       metavar="FILE",
                       dest="continue_fpath",
                       help='continue from one certain checkpoint')
        # PyTorch >= 2.0's torch.distributed.launch passes the dashed
        # --local-rank spelling; accept both so either launcher works.
        p.add_argument('--local_rank', '--local-rank', dest='local_rank',
                       default=0, type=int,
                       help='process rank on node')

    def data_parallel(self, model):
        """Wrap ``model`` for multi-GPU execution.

        Uses apex's ``DistributedDataParallel`` in distributed mode and
        ``torch.nn.DataParallel`` otherwise.
        """
        if self.distributed:
            return DistributedDataParallel(model)
        return torch.nn.DataParallel(model)

    def _per_process_batch_size(self):
        """Return the batch size each process should load.

        ``args.batch_size`` is treated as the global batch size; in distributed
        mode it is split across ``world_size`` processes with floor division.

        Raises:
            ValueError: if the split would leave a process with an empty batch.
        """
        batch_size = self.args.batch_size
        if not self.distributed:
            return batch_size
        per_process = batch_size // self.world_size
        if per_process < 1:
            raise ValueError(
                "batch_size ({}) must be at least world_size ({}) so every "
                "process gets a non-empty batch".format(
                    batch_size, self.world_size))
        return per_process

    def _build_loader(self, dataset, shuffle):
        """Create ``(loader, sampler)`` for ``dataset``.

        In distributed mode a ``DistributedSampler`` is created and the
        ``shuffle`` request is forwarded to it, because ``DataLoader`` does not
        allow ``shuffle=True`` together with an explicit sampler. ``sampler``
        is None when not distributed.
        """
        sampler = None
        if self.distributed:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, shuffle=shuffle)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self._per_process_batch_size(),
            num_workers=self.args.num_workers,
            drop_last=False,
            shuffle=shuffle and sampler is None,
            pin_memory=True,
            sampler=sampler)
        return loader, sampler

    def get_train_loader(self, train_dataset):
        """Return ``(train_loader, train_sampler)`` with shuffled samples.

        ``train_sampler`` is None unless distributed; callers presumably need
        it to call ``set_epoch`` each epoch (TODO confirm at call sites).
        """
        return self._build_loader(train_dataset, shuffle=True)

    def get_test_loader(self, test_dataset):
        """Return ``(test_loader, test_sampler)`` without shuffling.

        ``test_sampler`` is None unless distributed.
        """
        return self._build_loader(test_dataset, shuffle=False)

    def all_reduce_tensor(self, tensor, norm=True):
        """Reduce ``tensor`` across processes.

        In distributed mode this delegates to the module-level
        ``all_reduce_tensor`` helper with this engine's ``world_size`` and
        ``norm``. Otherwise it returns ``torch.mean(tensor)``; note that
        ``norm`` is ignored on that path.
        """
        if self.distributed:
            return all_reduce_tensor(tensor, world_size=self.world_size,
                                     norm=norm)
        return torch.mean(tensor)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        """Release cached CUDA memory; log, but never suppress, exceptions."""
        torch.cuda.empty_cache()
        if type is not None:
            # __exit__ only runs after __init__ succeeded, so any exception
            # seen here was raised inside the ``with`` block itself.
            logger.warning(
                "An exception occurred while running inside the Engine "
                "context, give up running process")
        return False