-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_classifier.py
151 lines (120 loc) · 5.86 KB
/
test_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import sys
import argparse
from tqdm import tqdm
import torch
import torchvision
import numpy as np
from vit.src.model import VisionTransformer as ViT
from vit.src.utils import MetricTracker, accuracy
from utils.ood_detection.ood_detector import MiniNet, CNN_IBP
from utils.load_data import get_test_dataloader
from utils.store_model import load_classifier
def get_model_from_args(args, model_name, num_classes):
    """
    get_model_from_args builds a (non-pretrained) classifier model from the specified arguments dotdict
    :args: dotdict containing all the arguments; ``args.device`` is always read, the "vit" branch
           additionally reads ``args.img_size``, ``args.num_heads`` and ``args.num_layers``
    :model_name: case-insensitive string stating what kind of model should be built:
                 "resnet", "vit", "cnn_ibp" or "mininet"
    :num_classes: integer specifying how many classes the classifier should be able to detect
                  (passed to the "resnet" and "vit" models only)
    :return: the classifier model, moved to ``args.device``
    :raises ValueError: if ``model_name`` is not one of the supported kinds
    """
    name = model_name.lower()
    if name == "resnet":
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained` in favour of `weights=None`;
        # kept as-is so older torchvision releases keep working.
        model = torchvision.models.resnet18(pretrained=False, num_classes=num_classes)
    elif name == "vit":
        model = ViT(image_size=(args.img_size, args.img_size),  # e.g. 224, 224
                    num_heads=args.num_heads,  # e.g. 12; fewer heads speed up training
                    num_layers=args.num_layers,  # e.g. 12; 5 gives a very small ViT
                    num_classes=num_classes,  # e.g. 2 for OOD detection, 10 or more for classification
                    contrastive=False,
                    timm=True)
    elif name == "cnn_ibp":
        # TODO currently not working, throwing error
        # "RuntimeError: mat1 dim 1 must match mat2 dim 0"
        model = CNN_IBP()
    elif name == "mininet":
        model = MiniNet()
    else:
        raise ValueError(f"Error - wrong model specified in 'args': {model_name!r}")
    # every branch moves its model to the configured device
    return model.to(device=args.device)
def test_classifier(args):
    """
    test_classifier runs one pass over all test samples to evaluate the classifier's performance.

    The classifier comes from ``load_classifier(args)`` and the batches from
    ``get_test_dataloader(args)``. Cross-entropy loss, top-1 / top-k accuracy
    (via ``accuracy``) and the error rate (via ``error_criterion``) are averaged
    over all samples -- each batch is weighted by its size so that a smaller
    final batch does not skew the result -- and printed to stdout.
    :args: dotdict containing all the arguments (this function itself reads ``args.device``)
    :raises ValueError: if the test dataloader yields no batches
    """
    classifier = load_classifier(args)

    # metric tracker (no writer, so nothing is logged to tensorboard here)
    metric_names = ['loss', 'acc1', 'acc5']
    metrics = MetricTracker(*metric_names, writer=None)
    metrics.reset()

    # NOTE(review): an earlier comment described this as a 50:50 ID/OOD mix with labels 0 (ID) / 1 (OOD);
    # confirm against get_test_dataloader.
    test_dataloader = get_test_dataloader(args)

    losses, acc1s, acc5s, batch_sizes = [], [], [], []
    running_errors = 0.0  # number of misclassified samples, summed over all batches
    with torch.no_grad():
        classifier.eval()
        criterion = torch.nn.CrossEntropyLoss().to(args.device)
        for batch_idx, (inputs, labels) in enumerate(tqdm(test_dataloader)):
            inputs, labels = inputs.to(device=args.device), labels.to(device=args.device)
            outputs = classifier(inputs)
            batch_size = labels.size(0)

            # error_criterion returns a per-batch rate; scale back to a count for an exact overall rate
            running_errors += error_criterion(outputs, labels).item() * batch_size
            if args.device == "cuda":
                torch.cuda.empty_cache()

            loss = criterion(outputs, labels)
            # top-5 is impossible with fewer than 5 output classes, so clamp k to the class count
            top_k = min(5, outputs.size(1))
            acc1, acc5 = accuracy(outputs, labels, topk=(1, top_k))

            losses.append(loss.item())
            acc1s.append(acc1.item())
            acc5s.append(acc5.item())
            batch_sizes.append(batch_size)

    if not batch_sizes:
        raise ValueError("test dataloader yielded no batches - nothing to evaluate")

    # per-batch values are batch means, so weight them by batch size to get per-sample means
    loss = np.average(losses, weights=batch_sizes)
    acc1 = np.average(acc1s, weights=batch_sizes)
    acc5 = np.average(acc5s, weights=batch_sizes)
    if metrics.writer is not None:
        metrics.writer.set_step(batch_idx, 'valid')
    metrics.update('loss', loss)
    metrics.update('acc1', acc1)
    metrics.update('acc5', acc5)
    log = {'val_' + k: v for k, v in metrics.result().items()}

    # print logged information to the screen
    for key, value in log.items():
        print(' {:15s}: {}'.format(str(key), value))

    avg_test_error = running_errors / sum(batch_sizes)
    print("\nOld metrics")
    print("Average Test Error:", avg_test_error)
    print("Finished Testing the Model")
def error_criterion(outputs, labels):
    """
    error_criterion computes the fraction of misclassified samples in one batch.
    :outputs: batch of model outputs; the predicted class is the index of the maximum along dim 1
    :labels: ground truth class indices, one per sample
    :return: 0-dim float tensor: number of wrong predictions divided by the batch size
    """
    predicted_classes = outputs.argmax(dim=1)
    mismatches = (predicted_classes != labels).float()
    n_samples = predicted_classes.size(0)
    return mismatches.sum() / n_samples
def parse_args(argv=None):
    """
    parse_args retrieves the arguments from the command line and parses them into an arguments namespace.
    :argv: optional list of argument strings to parse; ``None`` (default) parses ``sys.argv[1:]``
    :return: argparse.Namespace with the attributes model, classification_ckpt, device, num_workers and n_gpu
    """
    parser = argparse.ArgumentParser(description='Evaluate a trained classifier ("vit", "resnet", "mininet" or "cnn_ibp") on the test set and report loss, top-1/top-5 accuracy and the error rate')
    parser.add_argument('--model', type=str, default="vit", help='str - what model should be used to classify input samples "vit", "resnet", "mininet" or "cnn_ibp"')
    parser.add_argument('--classification_ckpt', type=str, default=None, help='str - path of pretrained model checkpoint')
    parser.add_argument('--device', type=str, default="cuda", help='str - cpu or cuda to calculate the tensors on')
    # argparse maps "--num-workers" / "--n-gpu" to the attributes num_workers / n_gpu
    parser.add_argument("--num-workers", type=int, default=8, help="number of workers")
    parser.add_argument("--n-gpu", type=int, default=2, help="number of gpus to use")
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()
    # sys.gettrace may be absent on some Python implementations (hence the getattr default);
    # only call it when it exists. A non-None trace function usually means a debugger is attached.
    gettrace = getattr(sys, 'gettrace', None)
    if gettrace is not None and gettrace():
        print("num_workers is set to 0 in debugging mode, otherwise no useful debugging possible")
        args.num_workers = 0
    test_classifier(args)
    print("finished all executions")