-
Notifications
You must be signed in to change notification settings - Fork 7
/
barlow_twins.py
103 lines (87 loc) · 2.78 KB
/
barlow_twins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import lightly.loss as loss
import lightly.models as models
import pytorch_lightning as pl
import torch
import torchvision
from PIL import ImageFile
from lightly.models.modules.heads import ProjectionHead
from torch import nn
from data.data_ukb import get_imaging_pretraining_data
# Share tensors between DataLoader worker processes through the file system
# (presumably to avoid exhausting file descriptors with many workers).
torch.multiprocessing.set_sharing_strategy("file_system")
# Default to GPU 2, but respect a device selection already made by the caller
# through the environment. This must run before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
# Let PIL open image files that are truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# --- Training hyper-parameters -------------------------------------------
max_epochs = 100
IMG_SIZE = 448  # image size passed to the data loader
PROJECTION_DIM = 128  # output width of the projection head
BATCH_SIZE = 32
# Gradients are accumulated over this many batches per optimizer step, so the
# effective batch size is BATCH_SIZE * ACCUMULATE_GRAD_BATCHES.
ACCUMULATE_GRAD_BATCHES = 2
LR = 1e-3
WEIGHT_DECAY = 1e-6
# NOTE(review): not referenced by the visible code; looks like a leftover from
# a SimCLR configuration (the Barlow Twins loss has no temperature).
TEMPERATURE = 0.1
class BarlowTwinsModel(pl.LightningModule):
    """Barlow Twins self-supervised pretraining of a ResNet-50 backbone.

    Wraps ``lightly.models.BarlowTwins`` around a randomly initialised
    torchvision ResNet-50 whose final (classification) child is removed, and
    replaces the wrapper's projection head with a two-layer head. Training
    minimises lightly's ``BarlowTwinsLoss`` between the outputs for two views
    of each image.

    Args:
        num_ftrs: Width of the backbone feature vector. Also used as the
            hidden width of the projection head. Defaults to 2048.
    """

    def __init__(self, num_ftrs=2048):
        super().__init__()
        # Randomly initialised ResNet-50 with its last child (the fc
        # classification layer) dropped, leaving the feature extractor.
        resnet = torchvision.models.resnet50()
        backbone = torch.nn.Sequential(*list(resnet.children())[:-1])

        # Barlow Twins wrapper around the ResNet backbone.
        self.resnet_barlow_twins = models.BarlowTwins(
            backbone,
            num_ftrs=num_ftrs,
            proj_hidden_dim=num_ftrs,
            out_dim=PROJECTION_DIM,
        )

        # Swap in a two-layer projection head. Each tuple is
        # (in_dim, out_dim, batch_norm, non_linearity) per lightly's
        # ProjectionHead: hidden layer with BatchNorm + ReLU, then a plain
        # output layer to PROJECTION_DIM.
        self.resnet_barlow_twins.projection_mlp = ProjectionHead(
            [
                (
                    self.resnet_barlow_twins.num_ftrs,
                    self.resnet_barlow_twins.proj_hidden_dim,
                    nn.BatchNorm1d(self.resnet_barlow_twins.proj_hidden_dim),
                    nn.ReLU(inplace=True),
                ),
                (
                    self.resnet_barlow_twins.proj_hidden_dim,
                    self.resnet_barlow_twins.out_dim,
                    None,
                    None,
                ),
            ]
        )
        self.criterion = loss.BarlowTwinsLoss()

    def forward(self, x):
        """Run a single view ``x`` through the Barlow Twins wrapper.

        Returns the wrapper's output (its projected embedding for ``x``).
        """
        return self.resnet_barlow_twins(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the Barlow Twins loss for one batch of two views."""
        # Layout taken from the unpacking: ((view0, view1), _, _). The two
        # ignored fields are presumably labels/filenames; confirm against
        # get_imaging_pretraining_data.
        (x0, x1), _, _ = batch
        z0, z1 = self.resnet_barlow_twins(x0, x1)
        # Named ssl_loss so it does not shadow the module-level `loss` alias.
        ssl_loss = self.criterion(z0, z1)
        self.log("train_loss_ssl", ssl_loss)
        return ssl_loss

    def configure_optimizers(self):
        """Return Adam over the Barlow Twins parameters with cosine LR decay.

        A scheduler returned in a plain list is stepped once per epoch by
        Lightning, so ``T_max`` is given in epochs. The learning rate anneals
        from ``LR`` to 0 over ``max_epochs`` epochs.
        """
        optim = torch.optim.Adam(
            self.resnet_barlow_twins.parameters(),
            lr=LR,
            weight_decay=WEIGHT_DECAY,
        )
        # T_max counts scheduler steps (epochs here), not batches. Using the
        # per-epoch batch count left the LR essentially constant for the run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, T_max=max_epochs, eta_min=0, last_epoch=-1
        )
        return [optim], [scheduler]
# Build the model and print its architecture to the run log.
model = BarlowTwinsModel()
print(model)

# Keyword arguments for the pretraining loaders. "simclr" names the transform
# settings; presumably a two-view augmentation, since training_step unpacks
# two views per sample.
pretraining_data_kwargs = dict(
    num_workers=8,
    size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    train_pct=0.7,
    val_pct=0.1,
    tfms_settings="simclr",
)
# Only the training loader is used; the other two returned loaders are
# discarded.
dataloader, _, _ = get_imaging_pretraining_data(**pretraining_data_kwargs)

# Number of training batches per epoch. Kept as a module-level global because
# it may be read via `global` from elsewhere in this file.
training_set_len = len(dataloader)

trainer = pl.Trainer(
    max_epochs=max_epochs,
    gpus=1,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
)
trainer.fit(model, dataloader)
print("Finished Training")