-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_all_imagenet.py
116 lines (98 loc) · 4.66 KB
/
test_all_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
from data_loader.imagenet_lt_data_loaders import ImageNetLTDataLoader
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
import numpy as np
from parse_config import ConfigParser
import torch.nn.functional as F
def main(config):
    """Evaluate a trained checkpoint on every ImageNet-LT test distribution.

    Builds the architecture named in ``config['arch']``, restores the weights
    stored under ``'state_dict'`` in the checkpoint at ``config.resume``, runs
    ``validation`` once per test split and finally prints the collected
    results for all splits.

    Args:
        config: Parsed configuration object. It is indexed like a mapping and
            provides ``get_logger``, ``init_obj`` and ``resume``.
    """
    logger = config.get_logger('test')

    # build model architecture; when the architecture exposes ``returns_feat``
    # it is forced to False for this evaluation run
    if 'returns_feat' in config['arch']['args']:
        model = config.init_obj('arch', module_arch, allow_override=True, returns_feat=False)
    else:
        model = config.init_obj('arch', module_arch)

    # Choose the device before loading so the checkpoint tensors can be
    # remapped onto it; without map_location a checkpoint saved from a GPU
    # cannot be deserialized on a CPU-only machine.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        # Wrap before loading the weights, as the original script did.
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    model = model.to(device)
    model.eval()

    num_classes = config._config["arch"]["args"]["num_classes"]

    record_list = []
    test_distribution_set = ["forward50", "forward25", "forward10", "forward5", "forward2", "uniform", "backward2", "backward5", "backward10", "backward25", "backward50"]
    for test_distribution in test_distribution_set:
        test_txt = '/ImageNet_LT_%s.txt' % (test_distribution)
        print(test_txt)
        data_loader = ImageNetLTDataLoader(
            config['data_loader']['args']['data_dir'],
            batch_size=128,
            shuffle=False,
            training=False,
            num_workers=2,
            test_txt=test_txt
        )
        record = validation(data_loader, model, num_classes, device)
        record_list.append(record)

    print('='*25, ' Final results ', '='*25)
    # One record per distribution, in the same order they were evaluated.
    for distribution_name, record in zip(test_distribution_set, record_list):
        print(distribution_name+'\t')
        print(*record)
def mic_acc_cal(preds, labels):
    """Return the top-1 accuracy of ``preds`` against ``labels``.

    Args:
        preds: Predicted class indices.
        labels: Either the ground-truth class indices, or a 3-tuple
            ``(targets_a, targets_b, lam)`` whose two target sets are scored
            separately and blended with weights ``lam`` and ``1 - lam``
            (presumably mixup-style targets -- confirm with the caller).

    Returns:
        The fraction of correct predictions. The plain-label path yields a
        Python float; the tuple path yields the result of the tensor
        arithmetic.
    """
    if not isinstance(labels, tuple):
        # Plain targets: count exact matches and normalise by sample count.
        n_correct = (preds == labels).sum().item()
        return n_correct / len(labels)

    assert len(labels) == 3
    targets_a, targets_b, lam = labels
    hits_a = preds.eq(targets_a.data).cpu().sum().float()
    hits_b = preds.eq(targets_b.data).cpu().sum().float()
    # Blend the two hit counts with the mixing coefficient.
    blended_hits = lam * hits_a + (1 - lam) * hits_b
    return blended_hits / len(preds)
def validation(data_loader, model, num_classes, device):
    """Run ``model`` over ``data_loader`` and report top-1 accuracies.

    Prints the overall accuracy and the many-/medium-/few-shot accuracies,
    all as percentages rounded to two decimals.

    Args:
        data_loader: Iterable yielding ``(data, target)`` batches.
        model: Callable mapping a batch of inputs to per-class scores.
        num_classes: Number of classes; sizes the confusion matrix.
        device: Device on which the batches and all accumulators are placed.

    Returns:
        A 4-tuple ``(many_shot_acc, medium_shot_acc, few_shot_acc,
        overall_acc)`` of rounded percentages.
    """
    # Entries 0/1/2 of the stored array select the many-, medium- and
    # few-shot classes; they are used below to index the per-class accuracy.
    shot_list = np.load("./data/imagenet_lt_shot_list.npy")
    many_shot = shot_list[0]
    medium_shot = shot_list[1]
    few_shot = shot_list[2]

    # Allocate on ``device`` (not unconditionally on CUDA) so the accumulators
    # live where the model outputs and targets are.
    confusion_matrix = torch.zeros(num_classes, num_classes, device=device)
    # Seed the chunk lists with empty tensors so concatenation also works
    # when the loader yields no batches.
    logit_chunks = [torch.empty((0, num_classes), device=device)]
    label_chunks = [torch.empty(0, dtype=torch.long, device=device)]

    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Add one count per (true class, predicted class) pair in a single
            # call; accumulate=True sums repeated index pairs.
            true_idx = target.view(-1).long()
            pred_idx = output.argmax(dim=1).view(-1)
            confusion_matrix.index_put_(
                (true_idx, pred_idx),
                torch.ones_like(true_idx, dtype=confusion_matrix.dtype),
                accumulate=True,
            )

            logit_chunks.append(output)
            label_chunks.append(target)

        # Concatenate once instead of re-copying the running tensor per batch.
        total_logits = torch.cat(logit_chunks)
        total_labels = torch.cat(label_chunks)

    _, preds = F.softmax(total_logits.detach(), dim=1).max(dim=1)

    # Calculate the overall accuracy, ignoring samples labelled -1
    valid = total_labels != -1
    eval_acc_mic_top1 = mic_acc_cal(preds[valid], total_labels[valid])
    overall_pct = np.round(eval_acc_mic_top1 * 100, decimals=2)
    print('All top-1 Acc:', overall_pct)

    # Per-class accuracy is the diagonal over the row sums. A class with no
    # samples gives 0/0 = NaN, which propagates into its group mean.
    acc_per_class = confusion_matrix.diag() / confusion_matrix.sum(1)
    acc = acc_per_class.cpu().numpy()
    many_pct = np.round(acc[many_shot].mean() * 100, decimals=2)
    medium_pct = np.round(acc[medium_shot].mean() * 100, decimals=2)
    few_pct = np.round(acc[few_shot].mean() * 100, decimals=2)
    print("{}, {}, {}".format(many_pct, medium_pct, few_pct))
    return many_pct, medium_pct, few_pct, overall_pct
if __name__ == '__main__':
    # Command-line entry point: collect the options, build the project
    # configuration from them, and run the evaluation.
    parser = argparse.ArgumentParser(description='PyTorch Template')
    # (flags, help text) for each option; all are optional strings
    # defaulting to None.
    cli_options = (
        (('-c', '--config'), 'config file path (default: None)'),
        (('-r', '--resume'), 'path to latest checkpoint (default: None)'),
        (('-d', '--device'), 'indices of GPUs to enable (default: all)'),
    )
    for flags, help_text in cli_options:
        parser.add_argument(*flags, default=None, type=str, help=help_text)

    config = ConfigParser.from_args(parser)
    main(config)