-
Notifications
You must be signed in to change notification settings - Fork 166
/
linear_evaluation.py
207 lines (164 loc) · 6.34 KB
/
linear_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import os
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from simclr import SimCLR
from simclr.modules import LogisticRegression, get_resnet
from simclr.modules.transformations import TransformsSimCLR
from utils import yaml_config_hook
def inference(loader, simclr_model, device):
    """Encode every batch of ``loader`` and collect features and labels.

    Args:
        loader: Iterable of ``(x, y)`` batches. ``len(loader)`` is used for
            the progress message.
        simclr_model: Callable invoked as ``simclr_model(x, x)``; it must
            return four values, of which only the first is kept as the
            feature representation.
        device: Device each input batch is moved to before the forward pass.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays, one entry per sample,
        in the order the loader produced them.
    """
    collected_features = []
    collected_labels = []

    for batch_idx, (images, targets) in enumerate(loader):
        images = images.to(device)

        # Forward pass only; no autograd graph is needed for feature extraction.
        with torch.no_grad():
            representation, _, _, _ = simclr_model(images, images)

        batch_features = representation.detach().cpu().numpy()
        batch_labels = targets.numpy()

        # extend() adds one row per sample, so the lists stay sample-aligned.
        collected_features.extend(batch_features)
        collected_labels.extend(batch_labels)

        if batch_idx % 20 == 0:
            print(f"Step [{batch_idx}/{len(loader)}]\t Computing features...")

    features = np.array(collected_features)
    labels = np.array(collected_labels)
    print("Features shape {}".format(features.shape))
    return features, labels
def get_features(simclr_model, train_loader, test_loader, device):
    """Extract features and labels for both the train and the test loader.

    Args:
        simclr_model: Model handed through to ``inference``.
        train_loader: Loader whose batches are encoded first.
        test_loader: Loader whose batches are encoded second.
        device: Device handed through to ``inference``.

    Returns:
        ``(train_features, train_labels, test_features, test_labels)``.
    """
    train_features, train_labels = inference(train_loader, simclr_model, device)
    test_features, test_labels = inference(test_loader, simclr_model, device)
    return train_features, train_labels, test_features, test_labels
def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    """Wrap NumPy feature/label arrays in unshuffled ``DataLoader`` objects.

    Args:
        X_train: NumPy array of training features.
        y_train: NumPy array of training labels.
        X_test: NumPy array of test features.
        y_test: NumPy array of test labels.
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Both loaders iterate their
        data in array order (``shuffle=False``).
    """

    def _as_loader(features, targets):
        # from_numpy shares memory with the source arrays instead of copying.
        dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(features), torch.from_numpy(targets)
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False
        )

    loader_for_train = _as_loader(X_train, y_train)
    loader_for_test = _as_loader(X_test, y_test)
    return loader_for_train, loader_for_test
def train(args, loader, simclr_model, model, criterion, optimizer):
    """Run one optimisation pass of ``model`` over ``loader``.

    Args:
        args: Object exposing ``device``; every batch is moved there.
        loader: Iterable of ``(x, y)`` batches.
        simclr_model: Unused here; kept so existing call sites keep working.
        model: Classifier being trained.
        criterion: Loss callable invoked as ``criterion(output, y)``.
        optimizer: Optimizer stepping ``model``'s parameters.

    Returns:
        ``(loss_epoch, accuracy_epoch)``: the per-batch loss and per-batch
        accuracy, each summed over all batches. Callers divide by the number
        of batches to obtain means.
    """
    # test() switches the model to eval mode; make sure training always runs
    # in train mode regardless of what was called before.
    model.train()

    loss_epoch = 0
    accuracy_epoch = 0
    for x, y in loader:
        optimizer.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)
        loss = criterion(output, y)

        # Fraction of samples in this batch whose arg-max class matches y.
        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch
def test(args, loader, simclr_model, model, criterion, optimizer):
loss_epoch = 0
accuracy_epoch = 0
model.eval()
for step, (x, y) in enumerate(loader):
model.zero_grad()
x = x.to(args.device)
y = y.to(args.device)
output = model(x)
loss = criterion(output, y)
predicted = output.argmax(1)
acc = (predicted == y).sum().item() / y.size(0)
accuracy_epoch += acc
loss_epoch += loss.item()
return loss_epoch, accuracy_epoch
if __name__ == "__main__":
    # Every key of the YAML config becomes a CLI flag whose default value and
    # argparse type are taken from the configured value.
    # NOTE(review): with type=type(v), a boolean entry would parse any
    # non-empty string (including "False") as True — confirm the config has
    # no booleans that are overridden on the command line.
    parser = argparse.ArgumentParser(description="SimCLR")
    config = yaml_config_hook("./config/config.yaml")
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v, type=type(v))

    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Both splits use the same transform, so build it once.
    eval_transform = TransformsSimCLR(size=args.image_size).test_transform

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="train",
            download=True,
            transform=eval_transform,
        )
        test_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="test",
            download=True,
            transform=eval_transform,
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=True,
            download=True,
            transform=eval_transform,
        )
        test_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=False,
            download=True,
            transform=eval_transform,
        )
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )

    # Keep the last partial batch: dropping it would silently exclude test
    # samples from the reported final loss and accuracy.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # load pre-trained model from checkpoint
    simclr_model = SimCLR(encoder, args.projection_dim, n_features)
    model_fp = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.epoch_num))
    simclr_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
    simclr_model = simclr_model.to(args.device)
    # The encoder is only used for feature extraction, never trained here.
    simclr_model.eval()

    ## Logistic Regression
    n_classes = 10  # CIFAR-10 / STL-10
    model = LogisticRegression(simclr_model.n_features, n_classes)
    model = model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # Encode both splits once, then train the classifier on the cached
    # features instead of re-running the encoder every epoch.
    print("### Creating features from pre-trained context model ###")
    (train_X, train_y, test_X, test_y) = get_features(
        simclr_model, train_loader, test_loader, args.device
    )

    arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
        train_X, train_y, test_X, test_y, args.logistic_batch_size
    )

    for epoch in range(args.logistic_epochs):
        # train() returns per-batch sums; divide by batch count for the means.
        loss_epoch, accuracy_epoch = train(
            args, arr_train_loader, simclr_model, model, criterion, optimizer
        )
        print(
            f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(arr_train_loader)}\t Accuracy: {accuracy_epoch / len(arr_train_loader)}"
        )

    # final testing
    loss_epoch, accuracy_epoch = test(
        args, arr_test_loader, simclr_model, model, criterion, optimizer
    )
    print(
        f"[FINAL]\t Loss: {loss_epoch / len(arr_test_loader)}\t Accuracy: {accuracy_epoch / len(arr_test_loader)}"
    )