run_imagenet.py
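# Overview (a summary inferred from the code below): this script evaluates
# zero-shot ImageNet classification twice, once with plain CLIP (global image
# feature against the class text features) and once with CALIP's
# parameter-free cross-modal attention, whose two extra logit terms are fused
# with the zero-shot logits via a grid search over the weights beta2 and beta3.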
import argparse
import os

import clip
import torch
import torch.nn.functional as F
import torchvision
import yaml
from tqdm import tqdm
from class_template import TEMPLATE, CLASS_NAME
from utils import accuracy, text_encode


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', dest='config', help='settings of CALIP in yaml format')
    args = parser.parse_args()
    print(args)
    return args


def main(cfg):
    backbone = cfg['backbone']

    # Per-backbone cache directories for image features and labels.
    total_feat_path = os.path.join('cache', 'total_features', backbone)
    label_path = os.path.join('cache', 'label', backbone)
    os.makedirs(total_feat_path, exist_ok=True)
    os.makedirs(label_path, exist_ok=True)

    # Load the CLIP backbone and its matching preprocessing transform.
    model, preprocess = clip.load(backbone)
    model.eval()

    print(f"Loading {cfg['dataset']} and templates for CALIP: {len(CLASS_NAME[cfg['dataset']])} classes, {len(TEMPLATE[cfg['dataset']])} templates")
    dataset = torchvision.datasets.ImageNet(cfg['data_root'] + cfg['dataset'], split='val', transform=preprocess)
    loader = torch.utils.data.DataLoader(dataset, batch_size=128, num_workers=8, shuffle=False)

    # Build one text feature per class from the prompt templates (see utils.text_encode).
    print('Encoding text features...')
    feat_t = text_encode(CLASS_NAME[cfg['dataset']], TEMPLATE[cfg['dataset']], model)
    print('Finish encoding text features.')
    if cfg['load_cache']:
        print('Loading cached image features and labels from ./cache/...')
        total_features = torch.load(os.path.join(total_feat_path, cfg['dataset'] + '.pt'))
        labels = torch.load(os.path.join(label_path, cfg['dataset'] + '.pt'))
    else:
        print('No cached features and labels, start encoding image features with CLIP...')
        total_features = []
        labels = []
        with torch.no_grad():
            for i, (images, label) in enumerate(tqdm(loader)):
                images = images.cuda()
                label = label.cuda()
                # encode_image is assumed to return token-wise visual features of
                # shape [num_tokens, batch, dim] (class token first); permute to
                # [batch, num_tokens, dim], then L2-normalize each token.
                features = model.encode_image(images)
                features = features.permute(1, 0, 2)
                features /= features.norm(dim=-1, keepdim=True)
                total_features.append(features)
                labels.append(label)
        total_features = torch.cat(total_features, dim=0)
        labels = torch.cat(labels, dim=0)
        torch.save(total_features, os.path.join(total_feat_path, cfg['dataset'] + '.pt'))
        torch.save(labels, os.path.join(label_path, cfg['dataset'] + '.pt'))

    # Split the class token (global image feature) from the spatial tokens.
    img_global_feat = total_features[:, 0, :]             # [N, dim]
    img_spatial_feat = total_features[:, 1:, :]           # [N, tokens, dim]
    img_spatial_feat = img_spatial_feat.permute(0, 2, 1)  # [N, dim, tokens]

    # ------------------------------------------ CLIP Zero-shot ------------------------------------------
    logits = 100. * img_global_feat @ feat_t
    acc, _ = accuracy(logits, labels, n=img_global_feat.size(0))
    print(f"CLIP zero-shot accuracy: {acc:.2f}")

    # ------------------------------------------ CALIP Zero-shot -----------------------------------------
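    # A sketch of CALIP's parameter-free attention as implemented below: with
    # spatial features feat_v of shape [dim, tokens] and text features feat_t
    # of shape [dim, classes], the affinity A = feat_v^T @ feat_t has shape
    # [tokens, classes]. Softmax over the token axis pools visual context into
    # text-side features (feat_t_a); softmax over the class axis attends text
    # back onto the spatial tokens (feat_v_a). No learnable parameters are
    # involved; the "* 2" acts as a fixed sharpening scale.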
    def get_logits():
        with torch.no_grad():
            logits1 = []
            logits2 = []
            for i, feat_v in enumerate(tqdm(img_spatial_feat)):
                # Affinity between spatial tokens and class texts: [tokens, classes].
                A_weight = torch.matmul(feat_v.permute(1, 0), feat_t) * 2
                A_weight1 = F.softmax(A_weight, dim=0)  # softmax over spatial tokens
                A_weight2 = F.softmax(A_weight, dim=1)  # softmax over classes
                feat_t_a = torch.matmul(feat_v, A_weight1)                # attended text features, [dim, classes]
                feat_v_a = torch.matmul(A_weight2, feat_t.permute(1, 0))  # attended visual tokens, [tokens, dim]
                # Pool the attended tokens with mean + max over the token axis.
                feat_v_a = feat_v_a.mean(0) + feat_v_a.max(0)[0]
                l1 = 100. * img_global_feat[i] @ feat_t_a
                l2 = 100. * feat_v_a @ feat_t
                logits1.append(l1.unsqueeze(0))
                logits2.append(l2.unsqueeze(0))
            logits1 = torch.cat(logits1, dim=0)
            logits2 = torch.cat(logits2, dim=0)
        return logits1, logits2

    if cfg['search']:
        logits1, logits2 = get_logits()
        # 200-point grids over [0.001, cfg['beta2']) and [0.001, cfg['beta3']).
        beta2_list = [i * (cfg['beta2'] - 0.001) / 200 + 0.001 for i in range(200)]
        beta3_list = [i * (cfg['beta3'] - 0.001) / 200 + 0.001 for i in range(200)]
        print('-' * 20)
        print('Start searching...')
        print('    beta1 = 1.0')
        print('    beta2 searching range: [0.001, ' + str(cfg['beta2']) + ']')
        print('    beta3 searching range: [0.001, ' + str(cfg['beta3']) + ']')
        print('-' * 20)
        best_acc = 0.
        best_beta2 = 0.
        best_beta3 = 0.
        # The zero-shot term is constant, so compute it once outside the grid.
        logits_zs = 100. * img_global_feat @ feat_t
        for beta2 in beta2_list:
            for beta3 in beta3_list:
                # CALIP fusion: zero-shot logits plus the two attention branches.
                logits = logits_zs + logits1 * beta2 + logits2 * beta3
                acc, _ = accuracy(logits, labels, n=img_global_feat.size(0))
                if acc > best_acc:
                    print('New best setting, beta1: {:.4f}; beta2: {:.4f}; beta3: {:.4f}; Acc: {:.2f}'.format(1, beta2, beta3, acc))
                    best_acc = acc
                    best_beta2 = beta2
                    best_beta3 = beta3
        print(f"Finish searching {cfg['dataset']} on backbone {cfg['backbone']}. Best beta2: {best_beta2:.4f}, beta3: {best_beta3:.4f}. Final Acc: {best_acc:.2f}")


if __name__ == '__main__':
    args = get_arguments()
    assert os.path.exists(args.config)
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.Loader)
    print(cfg)
    main(cfg)
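
# How to run, as a minimal sketch: the config path and the values below are
# hypothetical; only the keys are inferred from the cfg[...] accesses above.
#
#   python run_imagenet.py --config configs/imagenet.yaml
#
# where configs/imagenet.yaml might look like:
#
#   backbone: RN50       # a CLIP model name, e.g. from clip.available_models()
#   dataset: imagenet    # key into CLASS_NAME / TEMPLATE in class_template.py
#   data_root: ./data/   # ImageNet folder is resolved as data_root + dataset
#   load_cache: False    # reuse features cached under ./cache/ when True
#   search: True         # run the beta2/beta3 grid search
#   beta2: 2.0           # upper bound of the beta2 search range
#   beta3: 2.0           # upper bound of the beta3 search range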