-
Notifications
You must be signed in to change notification settings - Fork 274
/
finetune.py
291 lines (211 loc) · 11.2 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
'''
* RAM++ & RAM & Tag2Text finetune
* Written by Xinyu Huang
'''
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from ram.models import ram_plus, ram, tag2text
import utils
from utils import cosine_lr_schedule
from ram.data import create_dataset, create_sampler, create_loader
import clip
def build_text_embed(model_clip, caption):
    """Encode ``caption`` with CLIP's text encoder and L2-normalise the result.

    Args:
        model_clip: a loaded CLIP model exposing ``encode_text``.
        caption: a string or list of strings, passed straight to
            ``clip.tokenize`` (over-long captions are truncated).

    Returns:
        Text embeddings with unit L2 norm along the last dimension, computed
        without gradient tracking.
    """
    # Put the tokens wherever CLIP already lives. The previous version probed
    # torch.cuda.is_available() and called model_clip.cuda(), which moves the
    # module in place and silently relocated the caller's CLIP model to the GPU
    # even when training was configured for another device.
    clip_device = next(model_clip.parameters()).device
    with torch.no_grad():
        tokens = clip.tokenize(caption, truncate=True).to(clip_device)
        text_embeddings = model_clip.encode_text(tokens)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    return text_embeddings
def train_ram_plus(model, data_loader, optimizer, epoch, device, config, model_clip):
    """Run one RAM++ training epoch and return the epoch-averaged meters.

    For each batch, the captions are embedded with the frozen CLIP text
    encoder and CLIP image features are extracted (without gradients) from the
    224px view. The unweighted sum of the three losses returned by ``model``
    is then optimised.

    Args:
        model: RAM++ model (possibly DDP-wrapped). It is called as
            ``model(image, caption, image_tag, clip_image_feature, batch_text_embed)``
            and must return ``(loss_tag, loss_dis, loss_alignment)``.
        data_loader: yields ``(image, image_224, caption, image_tag, parse_tag)``
            tuples. Its sampler must implement ``set_epoch``.
        optimizer: stepped once per batch.
        epoch: current epoch index, used for the log header and the sampler.
        device: target device for the image tensors.
        config: not read here; kept so all trainers share a signature.
        model_clip: CLIP model used as a teacher for text and image features.

    Returns:
        dict mapping every meter name to its global average, formatted with
        three decimals.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    meter_formats = (
        ('lr', '{value:.6f}'),
        ('loss_tag', '{value:.4f}'),
        ('loss_dis', '{value:.4f}'),
        ('loss_alignment', '{value:.4f}'),
    )
    for meter_name, meter_fmt in meter_formats:
        metric_logger.add_meter(meter_name, utils.SmoothedValue(window_size=50, fmt=meter_fmt))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    # Lets a DistributedSampler-style sampler reshuffle differently each epoch.
    data_loader.sampler.set_epoch(epoch)

    batches = metric_logger.log_every(data_loader, print_freq, header)
    for image, image_224, caption, image_tag, _parse_tag in batches:
        optimizer.zero_grad()

        text_embed = build_text_embed(model_clip, caption)
        image = image.to(device, non_blocking=True)
        image_224 = image_224.to(device, non_blocking=True)
        with torch.no_grad():
            clip_feature = model_clip.encode_image(image_224)

        loss_tag, loss_dis, loss_alignment = model(image, caption, image_tag, clip_feature, text_embed)
        total_loss = loss_tag + loss_dis + loss_alignment
        total_loss.backward()
        optimizer.step()

        metric_logger.update(loss_tag=loss_tag.item())
        metric_logger.update(loss_dis=loss_dis.item())
        metric_logger.update(loss_alignment=loss_alignment.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Aggregate meter statistics across all processes before reporting.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {name: f"{meter.global_avg:.3f}" for name, meter in metric_logger.meters.items()}
def train_ram(model, data_loader, optimizer, epoch, device, config, model_clip):
    """Run one RAM training epoch and return the epoch-averaged meters.

    CLIP image features are extracted (without gradients) from the 224px view
    and passed to ``model``. The tagging loss is rescaled so that its forward
    value matches the image-to-text loss (see the comment in the loop).

    Args:
        model: RAM model (possibly DDP-wrapped). It is called as
            ``model(image, caption, image_tag, parse_tag, clip_image_feature)``
            and must return ``(loss_t2t, loss_tag, loss_dis)``.
        data_loader: yields ``(image, image_224, caption, image_tag, parse_tag)``
            tuples. Its sampler must implement ``set_epoch``.
        optimizer: stepped once per batch.
        epoch: current epoch index, used for the log header and the sampler.
        device: target device for the image tensors.
        config: not read here; kept so all trainers share a signature.
        model_clip: CLIP model used as an image-feature teacher.

    Returns:
        dict mapping every meter name to its global average, formatted with
        three decimals.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    meter_formats = (
        ('lr', '{value:.6f}'),
        ('loss_t2t', '{value:.4f}'),
        ('loss_tag', '{value:.4f}'),
        ('loss_dis', '{value:.4f}'),
    )
    for meter_name, meter_fmt in meter_formats:
        metric_logger.add_meter(meter_name, utils.SmoothedValue(window_size=50, fmt=meter_fmt))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    # Lets a DistributedSampler-style sampler reshuffle differently each epoch.
    data_loader.sampler.set_epoch(epoch)

    batches = metric_logger.log_every(data_loader, print_freq, header)
    for image, image_224, caption, image_tag, parse_tag in batches:
        optimizer.zero_grad()

        image = image.to(device, non_blocking=True)
        image_224 = image_224.to(device, non_blocking=True)
        with torch.no_grad():
            clip_feature = model_clip.encode_image(image_224)

        loss_t2t, loss_tag, loss_dis = model(image, caption, image_tag, parse_tag, clip_feature)
        # Dividing by the detached ratio makes the tag term's forward value equal
        # loss_t2t. Its gradient is loss_tag's gradient scaled by loss_t2t/loss_tag.
        # Note: this yields nan/inf if loss_t2t or loss_tag is exactly zero.
        tag_scale = (loss_tag / loss_t2t).detach()
        total_loss = loss_t2t + loss_tag / tag_scale + loss_dis
        total_loss.backward()
        optimizer.step()

        metric_logger.update(loss_t2t=loss_t2t.item())
        metric_logger.update(loss_tag=loss_tag.item())
        metric_logger.update(loss_dis=loss_dis.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Aggregate meter statistics across all processes before reporting.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {name: f"{meter.global_avg:.3f}" for name, meter in metric_logger.meters.items()}
def train_tag2text(model, data_loader, optimizer, epoch, device, config):
    """Run one Tag2Text training epoch and return the epoch-averaged meters.

    Only the full-resolution image, the caption and ``parse_tag`` are used from
    each batch. No CLIP teacher is involved.

    Args:
        model: Tag2Text model (possibly DDP-wrapped). It is called as
            ``model(image, caption, parse_tag)`` and must return
            ``(loss_t2t, loss_tag)``.
        data_loader: yields 5-tuples whose 1st, 3rd and 5th items are image,
            caption and parse_tag. Its sampler must implement ``set_epoch``.
        optimizer: stepped once per batch.
        epoch: current epoch index, used for the log header and the sampler.
        device: target device for the image tensor.
        config: not read here; kept so all trainers share a signature.

    Returns:
        dict mapping every meter name to its global average, formatted with
        three decimals.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    meter_formats = (
        ('lr', '{value:.6f}'),
        ('loss_t2t', '{value:.4f}'),
        ('loss_tag', '{value:.4f}'),
    )
    for meter_name, meter_fmt in meter_formats:
        metric_logger.add_meter(meter_name, utils.SmoothedValue(window_size=50, fmt=meter_fmt))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    # Lets a DistributedSampler-style sampler reshuffle differently each epoch.
    data_loader.sampler.set_epoch(epoch)

    batches = metric_logger.log_every(data_loader, print_freq, header)
    for image, _image_224, caption, _image_tag, parse_tag in batches:
        optimizer.zero_grad()

        image = image.to(device, non_blocking=True)
        loss_t2t, loss_tag = model(image, caption, parse_tag)
        # Dividing by the detached ratio makes the tag term's forward value equal
        # loss_t2t. Its gradient is loss_tag's gradient scaled by loss_t2t/loss_tag.
        tag_scale = (loss_tag / loss_t2t).detach()
        total_loss = loss_t2t + loss_tag / tag_scale
        total_loss.backward()
        optimizer.step()

        metric_logger.update(loss_t2t=loss_t2t.item())
        metric_logger.update(loss_tag=loss_tag.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Aggregate meter statistics across all processes before reporting.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {name: f"{meter.global_avg:.3f}" for name, meter in metric_logger.meters.items()}
def _create_models(args, config, device):
    """Build the model selected by ``args.model_type`` and, for RAM/RAM++, a CLIP teacher.

    Returns:
        ``(model, model_clip)``. ``model_clip`` is ``None`` for Tag2Text, whose
        training loop does not use CLIP features.

    Raises:
        ValueError: if ``args.model_type`` is not a supported model type.
    """
    common_kwargs = dict(
        pretrained=args.checkpoint,
        image_size=config['image_size'],
        vit=config['vit'],
        vit_grad_ckpt=config['vit_grad_ckpt'],
        vit_ckpt_layer=config['vit_ckpt_layer'],
    )
    model_clip = None
    if args.model_type in ('ram_plus', 'ram'):
        print("Creating pretrained CLIP model")
        model_clip, _ = clip.load("ViT-B/16", device=device)
        print("Creating RAM model")
        builder = ram_plus if args.model_type == 'ram_plus' else ram
        model = builder(**common_kwargs)
    elif args.model_type == 'tag2text':
        print("Creating Tag2Text model")
        model = tag2text(tag_list='ram/data/ram_tag_list.txt', **common_kwargs)
    else:
        raise ValueError(f"Unsupported model type: {args.model_type!r}")
    return model, model_clip


def main(args, config):
    """Fine-tune RAM++, RAM or Tag2Text according to ``args`` and ``config``.

    Sets up (optionally distributed) training, builds the dataset, loader and
    model, then trains for ``config['max_epoch']`` epochs with a cosine LR
    schedule. After each epoch the main process saves
    ``checkpoint_XX.pth`` and appends the epoch stats to ``log.txt`` inside
    ``args.output_dir``.
    """
    utils.init_distributed_mode(args)
    device = torch.device(args.device)

    # Offset the seed by rank so every process gets a distinct but
    # reproducible random stream.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    datasets = [create_dataset('finetune', config, min_scale=0.2)]
    print('number of training samples: %d' % len(datasets[0]))
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    samplers = create_sampler(datasets, [True], num_tasks, global_rank)
    data_loader = create_loader(datasets, samplers, batch_size=[config['batch_size']], num_workers=[4],
                                is_trains=[True], collate_fns=[None])[0]

    #### Model ####
    print("Creating model")
    if args.checkpoint:
        print("load from:", args.checkpoint)
    model, model_clip = _create_models(args, config, device)
    model = model.to(device)

    ### Frozen CLIP model ###
    # Tag2Text has no CLIP teacher. The previous version dereferenced an
    # unassigned model_clip here and crashed with UnboundLocalError.
    if model_clip is not None:
        model_clip = model_clip.to(device)
        for param in model_clip.parameters():
            param.requires_grad = False

    ### Frozen label embedding for open-set recognition ###
    model.label_embed.requires_grad = False

    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                  lr=config['init_lr'], weight_decay=config['weight_decay'])

    start_epoch = 0
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    print("Start training")
    start_time = time.time()
    for epoch in range(start_epoch, config['max_epoch']):
        cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

        if args.model_type == 'ram_plus':
            train_stats = train_ram_plus(model, data_loader, optimizer, epoch, device, config, model_clip)
        elif args.model_type == 'ram':
            train_stats = train_ram(model, data_loader, optimizer, epoch, device, config, model_clip)
        else:  # 'tag2text'; _create_models already rejected anything else
            train_stats = train_tag2text(model, data_loader, optimizer, epoch, device, config)

        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         }
            save_obj = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth' % epoch))
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

        # dist.barrier() raises when no process group is initialised, so only
        # synchronise in distributed runs.
        if args.distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/pretrain.yaml')
    parser.add_argument("--model-type", type=str, choices=("ram_plus", "ram", "tag2text"), required=True)
    parser.add_argument('--output-dir', default='output/Pretrain')
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # NOTE(review): type=bool turns any non-empty string (even "False") into True.
    # Confirm whether utils.init_distributed_mode overrides this flag before
    # relying on it from the command line.
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    # NOTE(review): yaml.Loader is the unsafe loader. This is acceptable only
    # because the config path is supplied by the local user, not by untrusted input.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Snapshot the config next to the checkpoints. The with-block guarantees the
    # file is flushed and closed before training starts.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_out:
        yaml.dump(config, config_out)

    main(args, config)