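# demo.py
#
# Evaluate a pretrained ResNet-20 on the CIFAR-10 test set, report accuracy and
# calibration metrics (ECE, MCE, OE, SCE, ACE, TACE), recalibrate the network
# with temperature scaling, and save a confidence histogram and reliability
# diagram to plots/.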
import sys
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from scipy.special import softmax
import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.nn.init as init
import torchvision as tv
sys.path.insert(1, 'models/')
import resnet
import metrics
import recalibration
import visualization
np.random.seed(0)
PATH = './pretrained_models/cifar10_resnet20.pth'
net_trained = resnet.ResNet(resnet.BasicBlock, [3, 3, 3])  # the 20-layer CIFAR ResNet
net_trained.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
net_trained.eval()  # evaluation mode so BatchNorm/Dropout behave correctly at test time
# Data transforms
mean = [0.5071, 0.4867, 0.4408]
stdv = [0.2675, 0.2565, 0.2761]
test_transforms = tv.transforms.Compose([
tv.transforms.ToTensor(),
tv.transforms.Normalize(mean=mean, std=stdv),
])
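# NOTE: these normalization statistics must match the ones the checkpoint was
# trained with, otherwise both accuracy and calibration numbers will be off.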
# IMPORTANT! Temperature scaling must be tuned on the same held-out data,
# so we reuse this loader for the recalibration step below.
test_set = tv.datasets.CIFAR10(root='./data', train=False, transform=test_transforms, download=True)
testloader = torch.utils.data.DataLoader(test_set, pin_memory=True, batch_size=4)
correct = 0
total = 0
# First: collect all the logits and labels for the held-out set
logits_list = []
labels_list = []
with torch.no_grad():
    for images, labels in testloader:
        # outputs are the raw, unnormalized scores (logits)
        logits = net_trained(images)
        # keep logits and labels for the calibration metrics below
        logits_list.append(logits)
        labels_list.append(labels)
        # convert logits to probabilities
        output_probs = F.softmax(logits, dim=1)
        # the most likely class is the prediction
        probs, predicted = torch.max(output_probs, 1)
        # running accuracy
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
logits = torch.cat(logits_list)
labels = torch.cat(labels_list)
print('Accuracy of the network on the %d test images: %.2f %%' % (
    total, 100 * correct / total))
################
#metrics
ece_criterion = metrics.ECELoss()
# convert the collected torch tensors to numpy for the metric implementations
logits_np = logits.numpy()
labels_np = labels.numpy()
# ECE from raw logits, using 15 bins
print('ECE: %f' % (ece_criterion.loss(logits_np, labels_np, 15)))
# the same ECE computed from pre-softmaxed probabilities (last argument flags that these are not logits)
softmaxes = softmax(logits_np, axis=1)
print('ECE with probabilities: %f' % (ece_criterion.loss(softmaxes, labels_np, 15, False)))
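# For reference: with M confidence bins B_1..B_M over n samples, the expected
# calibration error is ECE = sum_m (|B_m| / n) * |acc(B_m) - conf(B_m)|; the 15
# above is the number of bins (see metrics.py for the exact binning used here).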
mce_criterion = metrics.MCELoss()
print('MCE: %f' % (mce_criterion.loss(logits_np, labels_np)))
oe_criterion = metrics.OELoss()
print('OE: %f' % (oe_criterion.loss(logits_np, labels_np)))
sce_criterion = metrics.SCELoss()
print('SCE: %f' % (sce_criterion.loss(logits_np, labels_np, 15)))
ace_criterion = metrics.ACELoss()
print('ACE: %f' % (ace_criterion.loss(logits_np, labels_np, 15)))
tace_criterion = metrics.TACELoss()
threshold = 0.01
print('TACE (threshold = %f): %f' % (threshold, tace_criterion.loss(logits_np, labels_np, threshold, 15)))
############
#recalibration
model = recalibration.ModelWithTemperature(net_trained)
# Tune the model temperature, and save the results
model.set_temperature(testloader)
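# Minimal sanity check (a sketch, not part of the original demo): recompute ECE
# on temperature-scaled logits. This assumes ModelWithTemperature exposes the
# learned scalar as `model.temperature` and that recalibration simply divides
# the logits by it, as in the standard temperature-scaling recipe; adjust to the
# actual API in recalibration.py if it differs.
with torch.no_grad():
    scaled_logits_np = (logits / model.temperature.item()).numpy()
print('ECE after temperature scaling: %f' % (ece_criterion.loss(scaled_logits_np, labels_np, 15)))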
############
#visualizations
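# The savefig calls below write into plots/; create the directory up front so
# the demo runs from a fresh checkout (assumption: plots/ may not exist yet).
import os
os.makedirs('plots', exist_ok=True)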
conf_hist = visualization.ConfidenceHistogram()
plt_test = conf_hist.plot(logits_np, labels_np, title="Confidence Histogram")
plt_test.savefig('plots/conf_histogram_test.png', bbox_inches='tight')
#plt_test.show()
rel_diagram = visualization.ReliabilityDiagram()
plt_test_2 = rel_diagram.plot(logits_np, labels_np, title="Reliability Diagram")
plt_test_2.savefig('plots/rel_diagram_test.png', bbox_inches='tight')
#plt_test_2.show()