-
Notifications
You must be signed in to change notification settings - Fork 2
/
training.py
181 lines (153 loc) · 6.87 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import time
import torch
def train(train_loader, model, criterion, optimizer, epoch, log, args):
    """Run one training epoch and return the average recorded loss.

    Before iterating, the epoch's tuples are (re)built with
    ``train_loader.dataset.create_epoch_tuples(model)``. Gradients are
    accumulated across batches and an optimizer step is taken every
    ``args.update_every`` batches.

    Args:
        train_loader: iterable of ``(inputs, target)`` batches, where
            ``inputs[q][imi]`` is element ``imi`` of training tuple ``q`` and
            ``target[q]`` is that tuple's target.
        model: network exposing ``model.meta['outputdim']``; each element it
            processes is moved to CUDA first.
        criterion: callable returning a scalar loss tensor (``.item()`` and
            ``.backward()`` are called on its result).
        optimizer: optimizer whose gradients are accumulated and stepped.
        epoch: epoch index; printed as ``epoch + 1``.
        log: writable text stream; every progress line is also written to it.
        args: namespace providing ``mode``, ``sym`` (read only for the
            pass-through modes below), ``update_every`` and ``print_freq``.

    Returns:
        ``losses.avg``: the plain mean of all losses recorded this epoch
        (one per tuple, or one per batch in ``'rand'`` mode).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # In these modes, unless ``args.sym`` is set, only the first element of a
    # tuple is passed through ``model``; the remaining elements are used
    # as-is (presumably precomputed descriptors -- TODO confirm with dataset).
    asym_modes = ('ts', 'ts_self', 'ts_rand', 'reg', 'reg_only_pos', 'rand_tpl_a')

    # create tuples for training (called for its effect on the dataset; the
    # returned average negative distance is not used here)
    train_loader.dataset.create_epoch_tuples(model)

    # switch to train mode, but keep BatchNorm layers in eval mode
    model.train()
    model.apply(set_batchnorm_eval)

    # zero out gradients
    optimizer.zero_grad()

    # args is not modified during the epoch, so decide the mode handling once
    is_rand = args.mode == 'rand'
    passthrough = args.mode in asym_modes and not args.sym

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        nq = len(inputs)  # number of training tuples
        if is_rand:
            # one row per tuple; the loss is computed once for the whole batch
            outputs = torch.zeros(nq, target[0].shape[0]).cuda()

        for q in range(nq):
            ni = len(inputs[q])  # number of elements in this tuple
            output = torch.zeros(model.meta['outputdim'], ni).cuda()
            for imi in range(ni):
                if imi == 0 or not passthrough:
                    output[:, imi] = model(inputs[q][imi].cuda()).squeeze()
                else:
                    # as_tensor avoids the copy-construct warning that
                    # torch.tensor emits when the element is already a tensor
                    output[:, imi] = torch.as_tensor(inputs[q][imi]).float().cuda()

            if is_rand:
                # assumes squeeze() yields a vector of length target[0].shape[0]
                # (e.g. one element per tuple) -- TODO confirm
                outputs[q, :] = output.squeeze()
            else:
                # reduce memory: loss and backward for this tuple only;
                # gradients accumulate until the optimizer step below
                loss = criterion(output, target[q].t().cuda())
                losses.update(loss.item())
                loss.backward()

        if is_rand:
            loss = criterion(outputs, torch.stack(target).cuda())
            losses.update(loss.item())
            loss.backward()

        if (i + 1) % args.update_every == 0:
            # one optimizer step using gradients accumulated over several
            # batches, then reset them for the next accumulation window
            optimizer.step()
            optimizer.zero_grad()
        # NOTE(review): if len(train_loader) is not a multiple of
        # args.update_every, gradients from the trailing batches are never
        # applied (they are zeroed at the start of the next call).

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0 or i == 0 or (i + 1) == len(train_loader):
            # implicit literal concatenation: backslash continuations inside
            # the literal used to embed source indentation into the log line
            out = ('>> Train: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                       epoch + 1, i + 1, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses))
            print(out)
            log.write(out + '\n')

    return losses.avg
def validate(val_loader, model, criterion, epoch, args):
    """Compute the validation loss over all tuples and return its average.

    The validation tuples are rebuilt with
    ``val_loader.dataset.create_epoch_tuples(model)`` and the model is put in
    eval mode. Embeddings for a whole batch are stacked column-wise into one
    ``(outputdim, nq * ni)`` matrix and the loss is evaluated once per batch.

    Args:
        val_loader: iterable of ``(inputs, target)`` batches, where
            ``inputs[q][imi]`` is element ``imi`` of tuple ``q``. Every tuple
            in a batch is assumed to have ``len(inputs[0])`` elements.
        model: network exposing ``model.meta['outputdim']``.
        criterion: callable returning a scalar loss tensor for the batch.
        epoch: epoch index; printed as ``epoch + 1``.
        args: namespace providing ``mode``, ``sym`` and ``print_freq``.

    Returns:
        ``losses.avg``: the sum of batch losses divided by the total number
        of tuples (each batch records ``loss / nq`` with weight ``nq``).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()

    # In these modes, unless ``args.sym`` is set, only the first element of a
    # tuple goes through ``model``; the others are used as-is (presumably
    # precomputed descriptors -- TODO confirm with the dataset).
    asym_modes = ('ts', 'reg', 'reg_only_pos', 'ts_self', 'ts_rand', 'rand_tpl_a')

    # create tuples for validation (called for its effect on the dataset;
    # the returned average negative distance is not used here)
    val_loader.dataset.create_epoch_tuples(model)

    # switch to evaluate mode
    model.eval()

    passthrough = args.mode in asym_modes and not args.sym

    end = time.time()
    for i, (inputs, target) in enumerate(val_loader):
        nq = len(inputs)  # number of validation tuples
        ni = len(inputs[0])  # number of elements per tuple

        # no backward pass here, so do not build the autograd graph
        with torch.no_grad():
            output = torch.zeros(model.meta['outputdim'], nq * ni).cuda()
            for q in range(nq):
                for imi in range(ni):
                    # every mode (including 'rand') writes tuple q's elements
                    # to its own columns; 'rand' previously wrote to column
                    # imi, overwriting earlier tuples and leaving zeros
                    col = q * ni + imi
                    if imi == 0 or not passthrough:
                        output[:, col] = model(inputs[q][imi].cuda()).squeeze()
                    else:
                        output[:, col] = torch.as_tensor(inputs[q][imi]).float().cuda()

            # compute loss for the full batch
            if args.mode == 'rand':
                # rows of output.t() are matched against the stacked targets
                loss = criterion(output.t(), torch.stack(target).cuda())
            elif args.sym:
                loss = criterion(output, torch.cat(target).cuda().t())
            else:
                loss = criterion(output, torch.cat(target).cuda())

        # record loss
        losses.update(loss.item() / nq, nq)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0 or i == 0 or (i + 1) == len(val_loader):
            print('>> Val: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                      epoch + 1, i + 1, len(val_loader), batch_time=batch_time, loss=losses))

    return losses.avg
class AverageMeter:
    """Keep the latest value of a metric together with its running mean.

    ``val`` is the most recently recorded value. ``avg`` is the weighted
    mean ``sum / count``, where each update contributes ``val * n`` to
    ``sum`` and ``n`` to ``count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, counted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # refreshed on every update so ``avg`` is always current
        self.avg = self.sum / self.count
def set_batchnorm_eval(m):
    """Switch ``m`` to eval mode when its class name contains 'BatchNorm'.

    Meant to be used as ``model.apply(set_batchnorm_eval)`` right after
    ``model.train()``. Training processes one image at a time, so batch
    statistics would not be per batch. Keeping BatchNorm layers in eval mode
    freezes their running mean/std and reuses the pretrained (ImageNet)
    statistics.

    The affine scale and bias are deliberately left trainable: this function
    never touches ``requires_grad``.
    """
    if 'BatchNorm' in m.__class__.__name__:
        m.eval()