-
Notifications
You must be signed in to change notification settings - Fork 198
/
Pretrain.py
203 lines (154 loc) · 7.58 KB
/
Pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
'''
* Copyright (c) 2021, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from models.model_pretrain import ALBEF
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer
import utils
from dataset import create_dataset, create_sampler, create_loader
from scheduler import create_scheduler
from optim import create_optimizer
def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config):
    """Run one pre-training epoch and return the averaged metrics.

    For each ``(image, text)`` batch the captions are tokenized, the model
    returns three losses (``loss_mlm``, ``loss_ita``, ``loss_itm``), and their
    unweighted sum is back-propagated with one optimizer step per batch.

    Args:
        model: the (possibly DDP-wrapped) model, called as
            ``model(image, text_input, alpha=alpha)`` and returning the
            three losses.
        data_loader: iterable of ``(image, text)`` batches.
        optimizer: optimizer stepped once per batch.
        tokenizer: tokenizer applied to the raw caption batch.
        epoch: zero-based epoch index.
        warmup_steps: number of warm-up scheduler steps; each step spans
            ``step_size`` (100) iterations of epoch 0.
        device: device each batch is moved to.
        scheduler: LR scheduler; stepped here only during epoch-0 warm-up.
        config: mapping that provides ``config['alpha']``.

    Returns:
        dict mapping each meter name to its global average, formatted as a
        string with three decimals.
    """
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
    metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    step_size = 100
    warmup_iterations = warmup_steps*step_size

    # Re-seed per-epoch shuffling when the loader's sampler supports it
    # (e.g. DistributedSampler). Inspecting the sampler instead of the
    # script-level global `args` keeps this function usable when imported.
    sampler = getattr(data_loader, 'sampler', None)
    if hasattr(sampler, 'set_epoch'):
        sampler.set_epoch(epoch)

    for i, (image, text) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        optimizer.zero_grad()

        image = image.to(device,non_blocking=True)
        text_input = tokenizer(text, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device)

        if epoch>0:
            alpha = config['alpha']
        else:
            # ramp alpha linearly from 0 up to config['alpha'] across the first epoch
            alpha = config['alpha']*min(1,i/len(data_loader))

        loss_mlm, loss_ita, loss_itm = model(image, text_input, alpha = alpha)

        loss = loss_mlm + loss_ita + loss_itm

        loss.backward()
        optimizer.step()

        metric_logger.update(loss_mlm=loss_mlm.item())
        metric_logger.update(loss_ita=loss_ita.item())
        metric_logger.update(loss_itm=loss_itm.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Warm-up: during epoch 0 advance the scheduler once every `step_size`
        # iterations until `warmup_iterations`; later epochs are not stepped here.
        if epoch==0 and i%step_size==0 and i<=warmup_iterations:
            scheduler.step(i//step_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
def main(args, config):
    """Build the pre-training pipeline from ``config`` and train the model.

    Initialises (optionally distributed) execution, seeds the RNGs, builds the
    ``'pretrain'`` dataset and loader, tokenizer, ALBEF model, optimizer and LR
    scheduler, optionally restores ``args.checkpoint``, then trains from
    ``start_epoch`` up to ``config['schedular']['epochs']``. After each epoch
    the main process saves a checkpoint and appends a JSON line to
    ``log.txt`` in ``args.output_dir``.

    Args:
        args: parsed command-line namespace (``device``, ``seed``,
            ``text_encoder``, ``checkpoint``, ``resume``, ``output_dir``, and
            the distributed fields consumed by ``utils.init_distributed_mode``).
        config: configuration mapping loaded from the YAML config file.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so each process gets a distinct seed)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    start_epoch = 0
    max_epoch = config['schedular']['epochs']
    warmup_steps = config['schedular']['warmup_epochs']

    #### Dataset ####
    print("Creating dataset")
    datasets = [create_dataset('pretrain', config)]

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True], num_tasks, global_rank)
    else:
        samplers = [None]

    data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]

    tokenizer = BertTokenizer.from_pretrained(args.text_encoder)

    #### Model ####
    print("Creating model")
    model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer, init_deit=True)

    model = model.to(device)

    arg_opt = utils.AttrDict(config['optimizer'])
    optimizer = create_optimizer(arg_opt, model)
    arg_sche = utils.AttrDict(config['schedular'])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)

    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        state_dict = checkpoint['model']
        if args.resume:
            # continue an interrupted run: restore optimizer/scheduler state and the epoch counter
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            start_epoch = checkpoint['epoch']+1
        else:
            # weights-only init: interpolate_pos_embed presumably adapts the stored
            # position embeddings to this model's visual encoders — confirm in models.vit
            pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
            m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],model.visual_encoder_m)
            state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
            state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped
        model.load_state_dict(state_dict)
        print('load checkpoint from %s'%args.checkpoint)

    # keep a handle on the unwrapped module so checkpoints store plain parameter names
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    print("Start training")
    start_time = time.time()

    for epoch in range(start_epoch, max_epoch):

        # epoch 0 is warmed up inside train(); afterwards step once per epoch
        if epoch>0:
            lr_scheduler.step(epoch+warmup_steps)

        train_stats = train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config)
        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                        }
            save_obj = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'config': config,
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))

            with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
                f.write(json.dumps(log_stats) + "\n")

        # Only synchronise when a process group exists; an unconditional
        # barrier raises in single-process (non-distributed) runs.
        if args.distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` turns any non-empty string (including
    ``"False"``) into True, so the text is interpreted explicitly here.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/Pretrain.yaml')
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--resume', default=False, type=str2bool)
    parser.add_argument('--output_dir', default='Pretrain/')
    parser.add_argument('--text_encoder', default='bert-base-uncased')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=str2bool)
    args = parser.parse_args()

    # NOTE(security): yaml.Loader is the full (non-safe) loader and can construct
    # arbitrary Python objects from tagged YAML; only load trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # snapshot the effective config next to the checkpoints (closed deterministically)
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    main(args, config)