# SimCLR code based on https://github.com/sthalles/SimCLR/blob/master/simclr.py
import logging
import os
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import yaml
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

torch.manual_seed(0)

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

def save_config_file(model_checkpoints_folder, args):
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
    with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

class SimCLR(object):

    def __init__(self, train_loader, epochs=200, batch_size=112):
        self.backbone = torchvision.models.resnet18(pretrained=True)
        dim_mlp = self.backbone.fc.in_features
        # Add an MLP projection head on top of the backbone encoder.
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.backbone.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.train_loader = train_loader
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=len(train_loader),
                                                                    eta_min=0, last_epoch=-1)
        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        # Must match the dataloader's batch size.
        self.batch_size = batch_size
        self.n_views = 2
        self.temperature = 0.5
        self.epochs = epochs
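
    # Shape note (added for clarity, assuming the defaults batch_size=112 and
    # n_views=2): `features` is (224, D), the similarity matrix is (224, 224),
    # and once the diagonal is masked out each row holds 1 positive and 222
    # negatives. `logits` is therefore (224, 223) with the positive always in
    # column 0, which is why the returned labels are all zeros.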
    def info_nce_loss(self, features):
        labels = torch.cat([torch.arange(self.batch_size) for i in range(self.n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels.to(self.device)

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Discard the main diagonal from both the labels and the similarity matrix.
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

        # Select and combine the positives.
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        # Select only the negatives.
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)

        logits = logits / self.temperature
        return logits, labels
    def train(self):
        train_loader = self.train_loader
        # Mixed precision is disabled here; GradScaler/autocast run as no-ops.
        scaler = GradScaler(enabled=False)
        n_iter = 0
        logging.info(f"Start SimCLR training for {self.epochs} epochs.")
        logging.info(f"Training on device: {self.device}.")

        for epoch_counter in range(self.epochs):
            running_loss = 0
            count = 0
            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)
                images = images.to(self.device)

                with autocast(enabled=False):
                    features = self.model(images)
                    logits, labels = self.info_nce_loss(features)
                    loss = self.criterion(logits, labels)

                running_loss += loss.item()
                count += 1
                self.optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()

                if n_iter % 20 == 0:
                    top1, top5 = accuracy(logits, labels, topk=(1, 5))
                    print('Epoch : ', epoch_counter, ' Iter : ', n_iter)
                    print('Loss : ', loss.item())
                    print('acc/top1 : ', top1[0].item())
                    print('acc/top5 : ', top5[0].item())
                    self.writer.add_scalar('loss', loss, global_step=n_iter)
                    self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
                    self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
                    self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)
                n_iter += 1

            print('Epoch ', epoch_counter, ' loss :', running_loss / count)
            # Warm up for the first 10 epochs before stepping the scheduler.
            if epoch_counter >= 10:
                self.scheduler.step()
            logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")

        logging.info("Training has finished.")
        # Save the final model checkpoint.
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.epochs)
        save_checkpoint({
            'epoch': self.epochs,
            'arch': 'ResNet18',
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(f"Model checkpoint and metadata have been saved at {self.writer.log_dir}.")