-
Notifications
You must be signed in to change notification settings - Fork 8
/
train.py
477 lines (386 loc) · 17.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
import argparse
from collections import OrderedDict
import json
import math
import os
import sys
import time
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from data import get_data
from tokenizer import SimpleTokenizer
import models
import losses
import utils
def get_args_parser():
    """Build the command-line parser for LaCLIP/CLIP training.

    The parser is created with ``add_help=False`` so that it can be passed as a
    ``parents=[...]`` entry to another ArgumentParser (as the ``__main__``
    entry point does).

    Returns:
        argparse.ArgumentParser: parser with data, schedule, optimizer, AMP,
        logging and distributed-training options registered.
    """
    parser = argparse.ArgumentParser(description='LaCLIP/CLIP training', add_help=False)

    # (option strings, add_argument keyword arguments). Registration order is
    # preserved because it determines the order options appear in --help.
    option_specs = [
        # --- data ---
        (("--train-data",), dict(type=str, default=None, help="Path to training csv.")),
        (('--root',), dict(type=str, default='./data/', help='Root directory of images.')),
        # list of filenames for augmented captions
        (('--augmented_caption_filelist',),
         dict(nargs='+', help='list of augmented caption filenames, seperated by space')),
        (('--aug-text',), dict(action='store_true', help='set to True for LaCLIP')),
        (('--imagenet-root',), dict(default='data/imagenet', type=str, help='path to imagenet dataset')),
        (('--output-dir',), dict(default='./output', type=str, help='output dir')),
        # --- model / checkpointing ---
        (('--model',), dict(default='CLIP_VITB16', type=str)),
        (('--resume',), dict(default='', type=str, help='path to resume from')),
        # --- schedule ---
        (('--epochs',), dict(default=25, type=int)),
        (('--warmup-epochs',), dict(default=1, type=int)),
        (('--start-epoch',), dict(default=0, type=int)),
        (('--batch-size',), dict(default=64, type=int, help='number of samples per-gpu')),
        (('--lr',), dict(default=3e-3, type=float)),
        (('--lr-start',), dict(default=1e-6, type=float, help='initial warmup lr')),
        (('--lr-end',), dict(default=1e-5, type=float, help='minimum final lr')),
        (('--update-freq',),
         dict(default=1, type=int, help='optimizer update frequency (i.e. gradient accumulation steps)')),
        # --- optimizer ---
        (('--wd',), dict(default=0.1, type=float)),
        (('--betas',), dict(default=(0.9, 0.98), nargs=2, type=float)),
        (('--eps',), dict(default=1e-8, type=float)),
        (('--disable-amp',),
         dict(action='store_true', help='disable mixed-precision training (requires more memory and compute)')),
        # --- logging / data loading ---
        (('--print-freq',), dict(default=10, type=int, help='print frequency')),
        (('-j', '--workers'),
         dict(default=10, type=int, metavar='N', help='number of data loading workers per process')),
        # --- distributed ---
        (('--world-size',), dict(default=1, type=int, help='number of nodes for distributed training')),
        (('--rank',), dict(default=0, type=int, help='node rank for distributed training')),
        (("--local_rank",), dict(type=int, default=0)),
        (('--dist-url',), dict(default='env://', type=str, help='url used to set up distributed training')),
        (('--dist-backend',), dict(default='nccl', type=str)),
        (('--seed',), dict(default=0, type=int)),
        (('--gpu',), dict(default=None, type=int, help='GPU id to use.')),
    ]
    for flags, spec in option_specs:
        parser.add_argument(*flags, **spec)
    return parser
best_acc1 = 0


def _build_optimizer(model, args):
    """Create an AdamW optimizer with a decayed and a non-decayed parameter group.

    Trainable parameters with ``ndim >= 2`` whose name contains none of
    'bias', 'ln' or 'bn' get ``weight_decay=args.wd``; every other trainable
    parameter gets ``weight_decay=0``. Parameters with ``requires_grad=False``
    are left out entirely.
    """
    p_wd, p_non_wd = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue  # frozen weights
        if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n:
            p_non_wd.append(p)
        else:
            p_wd.append(p)
    optim_params = [{"params": p_wd, "weight_decay": args.wd},
                    {"params": p_non_wd, "weight_decay": 0}]
    return torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas,
                             eps=args.eps, weight_decay=args.wd)


def _resume(model, optimizer, scaler, args, best_acc1):
    """Restore training state and return the best top-1 accuracy seen so far.

    An explicit ``args.resume`` path takes precedence. Otherwise
    ``<output_dir>/checkpoint.pt`` is auto-loaded when it exists. When a
    checkpoint is loaded, ``args.start_epoch`` is updated in place. The
    incoming ``best_acc1`` is returned unchanged if nothing restores it.
    """
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
            return best_acc1
        print("=> loading resume checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location='cpu')
        epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0
        args.start_epoch = epoch
        result = model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(result)
        # An explicit --resume file may be partial (e.g. weights only), so
        # every entry other than 'state_dict' is treated as optional,
        # including 'best_acc1'.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        if 'best_acc1' in checkpoint:
            best_acc1 = checkpoint['best_acc1']
        print("=> loaded resume checkpoint '{}' (epoch {})"
              .format(args.resume, epoch))
        return best_acc1

    # auto-resume from latest checkpoint in output directory; this file is
    # always written by this script, so all entries are required.
    latest = os.path.join(args.output_dir, 'checkpoint.pt')
    if os.path.isfile(latest):
        print("=> loading latest checkpoint '{}'".format(latest))
        latest_checkpoint = torch.load(latest, map_location='cpu')
        args.start_epoch = latest_checkpoint['epoch']
        model.load_state_dict(latest_checkpoint['state_dict'])
        optimizer.load_state_dict(latest_checkpoint['optimizer'])
        scaler.load_state_dict(latest_checkpoint['scaler'])
        best_acc1 = latest_checkpoint['best_acc1']
        print("=> loaded latest checkpoint '{}' (epoch {})"
              .format(latest, latest_checkpoint['epoch']))
    return best_acc1


def main(args):
    """Train a CLIP/LaCLIP model and evaluate ImageNet zero-shot accuracy every epoch.

    Sets up (optionally distributed) training, builds the model, optimizer and
    GradScaler, resumes if possible, then for each epoch trains, validates,
    saves a checkpoint via ``utils.save_on_master`` and, on the main process,
    appends a JSON line of stats to ``<output_dir>/log.txt``. Updates the
    module-level ``best_acc1``.
    """
    global best_acc1

    utils.init_distributed_mode(args)

    # fix the seed for reproducibility (offset by rank so ranks differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # create model
    print("=> creating model: {}".format(args.model))
    model = getattr(models, args.model)()
    model.cuda(args.gpu)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], bucket_cap_mb=200)

    # define loss function (criterion) and optimizer
    criterion = losses.CLIPLoss().cuda(args.gpu)
    optimizer = _build_optimizer(model, args)
    scaler = amp.GradScaler(enabled=not args.disable_amp)

    # optionally resume from a checkpoint (takes precedence over autoresume)
    best_acc1 = _resume(model, optimizer, scaler, args, best_acc1)

    cudnn.benchmark = True

    # Data loading code
    print("=> creating dataset")
    tokenizer = SimpleTokenizer()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    val_dataset = datasets.ImageFolder(os.path.join(args.imagenet_root, 'val'), transform=val_transform)
    if args.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        val_sampler = None
    # Zero-shot accuracy does not depend on sample order, so the validation
    # set is not shuffled (avoids pointlessly consuming the global RNG).
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)

    data = get_data(args, (train_transform, val_transform), tokenizer=tokenizer)
    print('dataset size: %d' % data['train'].dataloader.num_samples)
    train_loader = data['train'].dataloader
    loader_len = train_loader.num_batches
    # One schedule entry per optimizer step (micro-batches / update_freq).
    lr_schedule = utils.cosine_scheduler(args.lr, args.lr_end, args.epochs,
                                         loader_len // args.update_freq,
                                         warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start)

    if utils.is_main_process() and args.output_dir is not None:
        args.log_dir = os.path.join(args.output_dir, 'tb_logs')
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    print(args)

    print("=> beginning training")
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data['train'].set_epoch(epoch)
        train_loader = data['train'].dataloader

        # train for one epoch
        train_stats = train(train_loader, log_writer, model, criterion, optimizer, scaler, epoch, lr_schedule, args)

        val_stats = validate_zeroshot(val_loader, model, tokenizer, args)
        acc1 = val_stats['acc1']

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        print("=> saving checkpoint")
        utils.save_on_master({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'best_acc1': best_acc1,
            'args': args,
        }, is_best, args.output_dir)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in val_stats.items()},
                     'epoch': epoch}

        # log test stats to log_writer (tensorboard)
        if log_writer is not None:
            for k, v in log_stats.items():
                if k.startswith('test'):
                    log_writer.add_scalar(k, v, epoch)

        if utils.is_main_process():
            with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f:
                f.write(json.dumps(log_stats) + '\n')
def train(train_loader, log_writer, model, criterion, optimizer, scaler, epoch, lr_schedule, args):
    """Run one epoch of CLIP training with gradient accumulation and AMP.

    Args:
        train_loader: iterable of input tuples; must expose ``num_batches``.
        log_writer: tensorboard SummaryWriter, or None to skip per-step logging.
        model: model (possibly DDP-wrapped); ``utils.get_model(model)`` must
            expose a ``logit_scale`` parameter.
        criterion: callable mapping model outputs to a dict of scalar tensors
            that includes ``'loss'``; its keys must be among ``metric_names``.
        optimizer: optimizer whose param-group lr is overwritten every step.
        scaler: ``torch.cuda.amp.GradScaler``.
        epoch: 0-based epoch index, used to index ``lr_schedule``.
        lr_schedule: learning rate per optimizer step, indexed by the global
            step ``iters_per_epoch * epoch + optim_iter``.
        args: parsed arguments (update_freq, gpu, disable_amp, batch_size,
            print_freq).

    Returns:
        dict: epoch averages of the metrics, plus the final ``'lr'`` and
        ``'logit_scale'``.
    """
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    mem = AverageMeter('Mem (GB)', ':6.1f')
    metric_names = ['loss', 'clip_loss', 'clip_acc']
    iters_per_epoch = train_loader.num_batches // args.update_freq
    metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names])
    progress = ProgressMeter(
        iters_per_epoch,
        [batch_time, data_time, mem, *metrics.values()],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    # Reported in the return value even if no optimizer step happens.
    logit_scale = utils.get_model(model).logit_scale.exp().item()

    end = time.time()
    for data_iter, inputs in enumerate(train_loader):
        optim_iter = data_iter // args.update_freq
        if optim_iter >= iters_per_epoch:
            # Trailing micro-batches that cannot fill a whole accumulation
            # window. Training on them would index past the end of
            # lr_schedule on the last epoch, and their gradients would never
            # be stepped here (they would leak into the next epoch's update).
            break

        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate according to its schedule
        it = iters_per_epoch * epoch + optim_iter  # global training iteration
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_schedule[it]

        inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs]

        # compute output
        with amp.autocast(enabled=not args.disable_amp):
            outputs = model(*inputs)
            loss_dict = criterion(outputs)
            # Out-of-place division: an in-place `/=` would also rescale
            # loss_dict['loss'], corrupting the logged/averaged loss when
            # update_freq > 1.
            loss = loss_dict['loss'] / args.update_freq

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()))
            sys.exit(1)

        scaler.scale(loss).backward()

        if (data_iter + 1) % args.update_freq != 0:
            continue

        # compute gradient and do SGD step
        scaler.step(optimizer)
        scaler.update()
        model.zero_grad(set_to_none=True)

        # clamp logit scale to [0, 100] (4.6052 ~= ln(100))
        utils.get_model(model).logit_scale.data.clamp_(0, 4.6052)
        logit_scale = utils.get_model(model).logit_scale.exp().item()

        for k in loss_dict:
            metrics[k].update(loss_dict[k].item(), args.batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # true division so the ':6.1f' format actually shows a fractional GB
        mem.update(torch.cuda.max_memory_allocated() / 1e9)

        # save to log_writer (tensorboard)
        if log_writer is not None:
            for k, v in loss_dict.items():
                log_writer.add_scalar(k, v.item(), it)
            log_writer.add_scalar('scaler', scaler.get_scale(), it)
            log_writer.add_scalar('logit', logit_scale, it)
            log_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

        if optim_iter % args.print_freq == 0:
            progress.display(optim_iter)

    progress.synchronize()
    return {**{k: v.avg for k, v in metrics.items()},
            'lr': optimizer.param_groups[0]['lr'],
            'logit_scale': logit_scale}
def validate_zeroshot(val_loader, model, tokenizer, args):
    """Measure zero-shot classification accuracy on the validation loader.

    A text classifier is built from ``imagenet_labels.json`` (read from the
    directory containing this script). Each label is inserted into a fixed
    set of prompt templates and encoded. The prompt embeddings are
    L2-normalised, averaged and re-normalised into one vector per class.
    Every image is then scored by cosine similarity against those vectors.

    Returns:
        dict: ``{'acc1': ..., 'acc5': ...}`` with the averaged top-1 / top-5
        accuracy in percent.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, top1, top5], prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    print('=> encoding captions')
    script_dir = os.path.dirname(os.path.realpath(__file__))
    prompt_templates = [
        "itap of a {}.",
        "a bad photo of the {}.",
        "a origami {}.",
        "a photo of the large {}.",
        "a {} in a video game.",
        "art of the {}.",
        "a photo of the small {}."
    ]
    with open(os.path.join(script_dir, 'imagenet_labels.json')) as label_file:
        class_names = json.load(label_file)

    def _unit(vec):
        # scale to unit L2 norm along the last dimension
        return vec / vec.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        class_vectors = []
        for class_name in class_names:
            prompts = [template.format(class_name) for template in prompt_templates]
            tokens = tokenizer(prompts).cuda(args.gpu, non_blocking=True)
            prompt_emb = _unit(utils.get_model(model).encode_text(tokens))
            class_vectors.append(_unit(prompt_emb.mean(dim=0)))
        text_features = torch.stack(class_vectors, dim=0)

        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # cosine similarity between normalised image and class embeddings
            img_emb = _unit(utils.get_model(model).encode_image(images))
            logits_per_image = img_emb @ text_features.t()

            hit1, hit5 = accuracy(logits_per_image, target, topk=(1, 5))
            hit1, hit5 = utils.scaled_all_reduce([hit1, hit5])
            batch_n = images.size(0)
            top1.update(hit1.item(), batch_n)
            top5.update(hit5.item(), batch_n)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

    progress.synchronize()
    print('0-shot * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return {'acc1': top1.avg, 'acc5': top5.avg}
class AverageMeter(object):
    """Computes and stores the average and current value of a metric.

    ``update(val, n)`` records ``val`` as the latest value and adds it to a
    running sum weighted by ``n``. ``synchronize()`` all-reduces the sum and
    count across processes when torch.distributed is active.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt  # format spec applied to val and avg in __str__
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times (e.g. a batch mean over n samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against n == 0 with no prior updates (would divide by zero).
        self.avg = self.sum / self.count if self.count else 0

    def synchronize(self):
        """Sum ``sum`` and ``count`` over all ranks and recompute ``avg``.

        No-op unless distributed training is available and initialised.
        """
        if not utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        total, count = t.tolist()
        # Keep the sum as a float: the tracked metrics (losses, accuracies,
        # timings) are fractional, and truncating to int biases avg downwards.
        self.sum = total
        self.count = count
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Print one tab-separated progress line for a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, a "[batch/total]" counter and every meter's summary."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header] + [str(meter) for meter in self.meters]))

    def synchronize(self):
        """Call ``synchronize()`` on every meter."""
        for meter in self.meters:
            meter.synchronize()

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``"[{:Nd}/total]"`` template, N being the digit count of the total."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[{}/{}]'.format(field, field.format(num_batches))
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for each k in ``topk``.

    ``output`` holds class scores along dim 1 and ``target`` one class index
    per row of ``output``. Each returned entry is a one-element tensor.
    """
    with torch.no_grad():
        num_samples = target.size(0)
        _, ranked = output.topk(max(topk), 1, True, True)
        ranked = ranked.t()  # (max k, batch): row r holds every sample's r-th guess
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]
if __name__ == '__main__':
    # Wrap the help-less parent parser so -h/--help is available on the CLI.
    args = argparse.ArgumentParser(
        'LaCLIP/CLIP training', parents=[get_args_parser()]
    ).parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)