# simsiam.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50


def D(p, z, version='simplified'):  # negative cosine similarity
    if version == 'original':
        z = z.detach()  # stop gradient
        p = F.normalize(p, dim=1)  # l2-normalize
        z = F.normalize(z, dim=1)  # l2-normalize
        return -(p * z).sum(dim=1).mean()
    elif version == 'simplified':  # same result, much faster; see the speed test in __main__
        return -F.cosine_similarity(p, z.detach(), dim=-1).mean()
    else:
        raise ValueError(f"unknown version: {version}")
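

# Illustrative sanity check (an addition, not part of the original file): the two
# branches of D agree because the cosine similarity of two vectors equals the dot
# product of their l2-normalized forms. Call this manually to verify.
def _demo_D_equivalence():
    p, z = torch.randn(4, 8), torch.randn(4, 8)
    assert torch.allclose(D(p, z, version='original'),
                          D(p, z, version='simplified'), atol=1e-6)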


class projection_MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
        super().__init__()
        ''' page 3 baseline setting
        Projection MLP. The projection MLP (in f) has BN ap-
        plied to each fully-connected (fc) layer, including its out-
        put fc. Its output fc has no ReLU. The hidden fc is 2048-d.
        This MLP has 3 layers.
        '''
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.BatchNorm1d(out_dim)  # BN on the output fc too, but no ReLU
        )
        self.num_layers = 3

    def set_layers(self, num_layers):
        self.num_layers = num_layers

    def forward(self, x):
        if self.num_layers == 3:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
        elif self.num_layers == 2:
            x = self.layer1(x)
            x = self.layer3(x)
        else:
            raise ValueError(f"num_layers must be 2 or 3, got {self.num_layers}")
        return x
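

# Illustrative usage (an addition, not part of the original file): the projector
# maps pooled backbone features into the 2048-d embedding space; set_layers(2)
# selects the 2-layer variant that skips layer2. Note BatchNorm1d requires a
# batch size > 1 in training mode.
def _demo_projection_mlp():
    proj = projection_MLP(in_dim=2048)
    z = proj(torch.randn(4, 2048))  # 3-layer path
    assert z.shape == (4, 2048)
    proj.set_layers(2)
    z = proj(torch.randn(4, 2048))  # 2-layer path: layer1 -> layer3
    assert z.shape == (4, 2048)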


class prediction_MLP(nn.Module):
    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048):  # bottleneck structure
        super().__init__()
        ''' page 3 baseline setting
        Prediction MLP. The prediction MLP (h) has BN applied
        to its hidden fc layers. Its output fc does not have BN
        (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers.
        The dimension of h’s input and output (z and p) is d = 2048,
        and h’s hidden layer’s dimension is 512, making h a
        bottleneck structure (ablation in supplement).
        '''
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        """
        Adding BN to the output of the prediction MLP h does not work
        well (Table 3d). We find that this is not about collapsing.
        The training is unstable and the loss oscillates.
        """

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x
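

# Illustrative check (an addition, not part of the original file): the
# 2048 -> 512 -> 2048 bottleneck keeps the predictor small, roughly 2.1M
# parameters versus the much wider projector.
def _demo_prediction_mlp():
    pred = prediction_MLP()
    assert pred(torch.randn(4, 2048)).shape == (4, 2048)
    n_params = sum(t.numel() for t in pred.parameters())
    print(f"predictor parameters: {n_params / 1e6:.1f}M")  # ~2.1M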


class SimSiam(nn.Module):
    def __init__(self, backbone=None):
        super().__init__()
        if backbone is None:
            # torchvision's resnet50 has no `output_dim` attribute; strip its
            # classification head and record the feature dimension so the
            # projector can attach directly to the pooled features.
            backbone = resnet50()
            backbone.output_dim = backbone.fc.in_features  # 2048
            backbone.fc = nn.Identity()
        self.backbone = backbone
        self.projector = projection_MLP(backbone.output_dim)
        self.encoder = nn.Sequential(  # f encoder
            self.backbone,
            self.projector
        )
        self.predictor = prediction_MLP()

    def forward(self, x1, x2):
        f, h = self.encoder, self.predictor
        z1, z2 = f(x1), f(x2)
        p1, p2 = h(z1), h(z2)
        L = D(p1, z2) / 2 + D(p2, z1) / 2  # symmetrized loss; D stop-grads z
        return {'loss': L}
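

# A minimal training-loop sketch (an addition; the loader and the SGD settings
# are assumptions, not this repo's exact recipe): `loader` is a hypothetical
# DataLoader yielding two augmented views (x1, x2) of each image batch.
def _train_sketch(model, loader, epochs=1, lr=0.05):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=0.9, weight_decay=1e-4)
    model.train()
    for _ in range(epochs):
        for x1, x2 in loader:
            loss = model(x1, x2)['loss']  # symmetrized negative cosine similarity
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()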


if __name__ == "__main__":
    model = SimSiam()
    x1 = torch.randn((2, 3, 224, 224))
    x2 = torch.randn_like(x1)
    model.forward(x1, x2)['loss'].backward()  # forward() returns a dict
    print("forward/backward check passed")

    z1 = torch.randn((200, 2560))
    z2 = torch.randn_like(z1)

    import time
    tic = time.time()
    print(D(z1, z2, version='original'))
    toc = time.time()
    print(toc - tic)

    tic = time.time()
    print(D(z1, z2, version='simplified'))
    toc = time.time()
    print(toc - tic)
    # Output:
    # tensor(-0.0010)
    # 0.005159854888916016
    # tensor(-0.0010)
    # 0.0014872550964355469