-
Notifications
You must be signed in to change notification settings - Fork 19
/
main.py
281 lines (232 loc) · 10.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import argparse
import datetime
import shutil
from pathlib import Path
from utils.config import get_config
from utils.optimizer import build_optimizer, build_scheduler
from utils.tools import AverageMeter, reduce_tensor, epoch_saving, load_checkpoint, generate_text, auto_resume_helper
from datasets.build import build_dataloader
from utils.logger import create_logger
import time
import numpy as np
import random
from apex import amp
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from datasets.blending import CutmixMixupBlending
from utils.config import get_config
from trainers import vificlip
def parse_option():
    """Parse command-line arguments and build the experiment config.

    Returns:
        tuple: ``(args, config)`` -- the parsed ``argparse.Namespace`` and the
        config object produced by ``get_config(args)``.

    Notes:
        The local-rank option accepts both ``--local_rank`` and the dashed
        ``--local-rank`` spelling that ``torch.distributed.launch`` passes
        from PyTorch 2.0 onward. When neither is given on the command line,
        the ``LOCAL_RANK`` environment variable (set by ``torchrun``) is
        used, falling back to -1.
    """
    parser = argparse.ArgumentParser()
    # ``required=True`` makes this default unreachable; it only hints at a typical value.
    parser.add_argument('--config', '-cfg', required=True, type=str, default='configs/k400/32_8.yaml')
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
    parser.add_argument('--output', type=str, default="exp")
    parser.add_argument('--resume', type=str)
    parser.add_argument('--pretrained', type=str)
    parser.add_argument('--only_test', action='store_true')
    parser.add_argument('--batch-size', type=int)
    parser.add_argument('--accumulation-steps', type=int)
    # dest remains ``local_rank`` (argparse derives it from the first long option string).
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        default=int(os.environ.get('LOCAL_RANK', -1)),
                        help='local rank for DistributedDataParallel')
    args = parser.parse_args()
    config = get_config(args)
    return args, config
def main(config):
    """Train and/or evaluate a ViFi-CLIP model as configured by ``config``.

    Steps, in order:
      1. build the train/val datasets and loaders;
      2. build the model from the training-set class names and move it to GPU;
      3. pick the loss from ``config.AUG``: soft-target CE with mixup/cutmix,
         label-smoothing CE, or plain CE;
      4. build optimizer and LR scheduler, optionally wrap with apex AMP, then
         wrap the model in ``DistributedDataParallel``;
      5. optionally (auto-)resume from a checkpoint;
      6. if ``config.TEST.ONLY_TEST`` is set, validate once and return;
      7. otherwise train for ``config.TRAIN.EPOCHS`` epochs, validating (and
         checkpointing on rank 0) every ``config.SAVE_FREQ`` epochs and on the
         final epoch;
      8. if ``config.TEST.MULTI_VIEW_INFERENCE`` is set, rebuild the loaders
         with 4 clips x 3 crops and validate once more.

    Uses the module-level ``logger`` created in the ``__main__`` block.
    """
    train_data, val_data, train_loader, val_loader = build_dataloader(logger, config)
    # Only the name of each (index, name) entry in ``train_data.classes`` is used.
    names = [name for _, name in train_data.classes]

    # Custom trainer for different variants of ViFi-CLIP
    model = vificlip.returnCLIP(config, logger=logger, class_names=names).cuda()

    aug = config.AUG
    mixup_fn = None
    if aug.MIXUP > 0:
        # Mixup/cutmix produce soft targets, hence the soft-target loss.
        criterion = SoftTargetCrossEntropy()
        mixup_fn = CutmixMixupBlending(num_classes=config.DATA.NUM_CLASSES,
                                       smoothing=aug.LABEL_SMOOTH,
                                       mixup_alpha=aug.MIXUP,
                                       cutmix_alpha=aug.CUTMIX,
                                       switch_prob=aug.MIXUP_SWITCH_PROB)
    elif aug.LABEL_SMOOTH > 0:
        criterion = LabelSmoothingCrossEntropy(smoothing=aug.LABEL_SMOOTH)
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = build_optimizer(config, model)
    lr_scheduler = build_scheduler(config, optimizer, len(train_loader))
    # AMP must wrap the bare model/optimizer before DDP wraps the model.
    if config.TRAIN.OPT_LEVEL != 'O0':
        model, optimizer = amp.initialize(models=model, optimizers=optimizer, opt_level=config.TRAIN.OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[config.LOCAL_RANK],
        broadcast_buffers=False,
        find_unused_parameters=False,
    )

    start_epoch, max_accuracy = 0, 0.0
    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if not resume_file:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
        else:
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        start_epoch, max_accuracy = load_checkpoint(config, model, optimizer, lr_scheduler, logger)
        if start_epoch > 1:
            # NOTE(review): this reset also fires for checkpoints picked up by
            # AUTO_RESUME above, so an auto-resumed run restarts at epoch 0 --
            # confirm that is intended.
            logger.info("resetting epochs no and max. accuracy to 0 after loading pre-trained weights")
            start_epoch, max_accuracy = 0, 0

    if config.TEST.ONLY_TEST:
        acc1 = validate(val_loader, model, config)
        logger.info(f"Accuracy of the network on the {len(val_data)} test videos: {acc1:.1f}%")
        return

    last_epoch = config.TRAIN.EPOCHS - 1
    for epoch in range(start_epoch, config.TRAIN.EPOCHS):
        train_loader.sampler.set_epoch(epoch)
        train_one_epoch(epoch, model, criterion, optimizer, lr_scheduler, train_loader, config, mixup_fn)

        is_eval_epoch = epoch % config.SAVE_FREQ == 0 or epoch == last_epoch
        if not is_eval_epoch:
            continue

        acc1 = validate(val_loader, model, config)
        logger.info(f"Accuracy of the network on the {len(val_data)} test videos: {acc1:.1f}%")
        is_best = acc1 > max_accuracy
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')
        if dist.get_rank() == 0 and (is_eval_epoch or is_best):
            epoch_saving(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, config.OUTPUT,
                         is_best)

    # Now doing the multi-view inference crop for videos
    # 4 CLIPs are obtained from each video, and for each CLIP, we get 3 crops (augmentations)
    if config.TEST.MULTI_VIEW_INFERENCE:
        config.defrost()
        config.TEST.NUM_CLIP = 4
        config.TEST.NUM_CROP = 3
        config.freeze()
        train_data, val_data, train_loader, val_loader = build_dataloader(logger, config)
        acc1 = validate(val_loader, model, config)
        logger.info(f"Accuracy of the network on the {len(val_data)} test videos: {acc1:.1f}%")
def train_one_epoch(epoch, model, criterion, optimizer, lr_scheduler, train_loader, config, mixup_fn):
    """Run one training epoch over ``train_loader``.

    Args:
        epoch: zero-based epoch index; used for logging and for the global
            iteration index ``epoch * len(train_loader) + idx`` passed to
            ``lr_scheduler.step_update``.
        model: the (DDP-wrapped) model; called as ``model(images)``.
        criterion: loss called as ``criterion(output, label_id)``.
        optimizer: optimizer stepped here (through apex AMP when
            ``config.TRAIN.OPT_LEVEL != 'O0'``).
        lr_scheduler: scheduler exposing ``step_update(global_iter)``.
        train_loader: yields dicts with ``"imgs"`` and ``"label"`` tensors.
        config: experiment config (reads TRAIN.*, DATA.NUM_FRAMES, PRINT_FREQ).
        mixup_fn: optional callable ``(images, labels) -> (images, labels)``;
            skipped when ``None``.

    With ``k = config.TRAIN.ACCUMULATION_STEPS > 1`` the loss is divided by k
    before backward and the optimizer steps every k iterations. Otherwise
    gradients are zeroed and the optimizer steps on every iteration. The
    logged ``tot_loss`` is the un-scaled batch loss, so it does not shrink
    with k.

    Uses the module-level ``logger`` created in the ``__main__`` block.
    """
    model.train()
    optimizer.zero_grad()

    accum_steps = config.TRAIN.ACCUMULATION_STEPS
    num_steps = len(train_loader)
    batch_time = AverageMeter()
    tot_loss_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, batch_data in enumerate(train_loader):
        images = batch_data["imgs"].cuda(non_blocking=True)
        label_id = batch_data["label"].cuda(non_blocking=True)
        label_id = label_id.reshape(-1)
        # Reshape to (-1, NUM_FRAMES, 3, H, W).
        images = images.view((-1, config.DATA.NUM_FRAMES, 3) + images.size()[-2:])

        if mixup_fn is not None:
            images, label_id = mixup_fn(images, label_id)

        output = model(images)
        loss = criterion(output, label_id)
        # Only the scaled loss is back-propagated (gradient accumulation).
        total_loss = loss / accum_steps

        if accum_steps == 1:
            optimizer.zero_grad()
        if config.TRAIN.OPT_LEVEL != 'O0':
            with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            total_loss.backward()

        if accum_steps > 1:
            if (idx + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        # Synchronize so the per-batch wall-clock time below is meaningful.
        torch.cuda.synchronize()

        # Log the un-scaled loss; logging total_loss would under-report by accum_steps.
        tot_loss_meter.update(loss.item(), len(label_id))
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.9f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'tot_loss {tot_loss_meter.val:.4f} ({tot_loss_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
@torch.no_grad()
def validate(val_loader, model, config):
    """Evaluate ``model`` on ``val_loader`` with multi-view score averaging.

    Each batch's ``"imgs"`` tensor is unpacked as ``(b, n*t, c, h, w)`` with
    ``t = config.DATA.NUM_FRAMES``. It is split into ``n`` views of shape
    ``(b, t, c, h, w)``, and the per-view softmax scores are summed before
    ranking.

    Args:
        val_loader: yields dicts with ``"imgs"`` and ``"label"`` tensors.
        model: model returning class scores reshapeable to ``(b, -1)``.
        config: experiment config (reads TEST.NUM_CLIP/NUM_CROP,
            DATA.NUM_FRAMES/NUM_CLASSES, TRAIN.OPT_LEVEL, PRINT_FREQ).

    Returns:
        Top-1 accuracy in percent (``acc1_meter.avg`` after ``sync()``).

    Acc@5 uses ``k = min(5, NUM_CLASSES)`` so datasets with fewer than five
    classes do not make ``topk`` raise. Uses the module-level ``logger``
    created in the ``__main__`` block.
    """
    model.eval()
    num_classes = config.DATA.NUM_CLASSES
    top5_k = min(5, num_classes)
    acc1_meter, acc5_meter = AverageMeter(), AverageMeter()

    logger.info(f"{config.TEST.NUM_CLIP * config.TEST.NUM_CROP} views inference")
    for idx, batch_data in enumerate(val_loader):
        _image = batch_data["imgs"]
        # Move labels to the GPU once per batch rather than once per view.
        label_id = batch_data["label"].reshape(-1).cuda(non_blocking=True)
        b, tn, c, h, w = _image.size()
        t = config.DATA.NUM_FRAMES
        n = tn // t
        _image = _image.view(b, n, t, c, h, w)

        tot_similarity = torch.zeros((b, num_classes)).cuda()
        for i in range(n):
            image_input = _image[:, i].cuda(non_blocking=True)  # [b,t,c,h,w]
            if config.TRAIN.OPT_LEVEL == 'O2':
                image_input = image_input.half()
            output = model(image_input)
            tot_similarity += output.view(b, -1).softmax(dim=-1)

        # Vectorized hit counts (same result as a per-sample loop, no per-element GPU sync).
        _, indices_1 = tot_similarity.topk(1, dim=-1)
        _, indices_5 = tot_similarity.topk(top5_k, dim=-1)
        acc1 = (indices_1[:, 0] == label_id).sum().item()
        acc5 = (indices_5 == label_id.unsqueeze(-1)).any(dim=-1).sum().item()

        acc1_meter.update(float(acc1) / b * 100, b)
        acc5_meter.update(float(acc5) / b * 100, b)
        if idx % config.PRINT_FREQ == 0:
            logger.info(
                f'Test: [{idx}/{len(val_loader)}]\t'
                f'Acc@1: {acc1_meter.avg:.3f}\t'
            )
    acc1_meter.sync()
    acc5_meter.sync()
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg
if __name__ == '__main__':
    # Parse CLI arguments and build the experiment config.
    args, config = parse_option()

    # Distributed setup: take rank/world size from the launcher's environment
    # when both are present; otherwise pass -1 through to init_process_group.
    launched_with_env = all(key in os.environ for key in ('RANK', 'WORLD_SIZE'))
    if launched_with_env:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank, world_size = -1, -1
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    dist.barrier(device_ids=[args.local_rank])

    # Per-process seed: base seed offset by the global rank.
    global_rank = dist.get_rank()
    seed = config.SEED + global_rank
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    cudnn.benchmark = True

    # Working directory and the module-level ``logger`` used by the functions above.
    Path(config.OUTPUT).mkdir(parents=True, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=global_rank, name=f"{config.MODEL.ARCH}")
    logger.info(f"working dir: {config.OUTPUT}")

    # Only rank 0 logs the full config and copies the config file into the output dir.
    if global_rank == 0:
        logger.info(config)
        shutil.copy(args.config, config.OUTPUT)

    main(config)