# test.py
# Evaluation script: restores a classifier checkpoint, classifies the images in
# an input directory and logs every misclassified file plus the mean accuracy.
import argparse
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from trainer import ClassifierTrainer
from utils import get_config, check_dir, get_local_time
# Command-line options for one evaluation run.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='configs/GTSRB.yaml',
                    help="net configuration")
# The default is assembled from path components instead of a backslash literal:
# written as 'GTSRB-new\RB\val\0-clear' in a normal string, '\v' and '\0' are
# escape sequences (vertical tab and NUL), so that default never named a real
# directory. os.path.join also yields the separator of the running platform,
# which the label parsing later in the script (split on os.sep) relies on.
parser.add_argument('-i', '--input_dir', type=str,
                    default=os.path.join('GTSRB-new', 'RB', 'val', '0-clear'),
                    help="input image path")
parser.add_argument('-o', '--output_dir', type=str, default='result-2/GTSRB/original/',
                    help="output image path")
parser.add_argument('-p', '--checkpoint', type=str,
                    default='checkpoints-new-2/0-0.2/outputs/GTSRB/checkpoints/classifier.pt',
                    help="checkpoint")
parser.add_argument('-l', '--log_name', type=str, default='0-0.2.log', help="log name")
parser.add_argument('-g', '--gpu_id', type=int, default=0, help="gpu id")
opts = parser.parse_args()
# Load the experiment settings from the file given by --config.
config = get_config(opts.config)

# Make --gpu_id the current CUDA device. Previously the option only affected
# map_location below, while the argument-less .cuda() calls always targeted the
# default device, so the checkpoint and the model could end up on different GPUs.
torch.cuda.set_device(opts.gpu_id)

# Build the classifier and restore its weights from the checkpoint.
trainer = ClassifierTrainer(config)
# NOTE(review): torch.load unpickles the file -- only point --checkpoint at
# files from a trusted source.
state_dict = torch.load(opts.checkpoint, map_location='cuda:{}'.format(opts.gpu_id))
trainer.net.load_state_dict(state_dict['net'])

# Report the training state stored alongside the weights; 'acc' is optional in
# the checkpoint and falls back to 0.0 when absent.
epochs = state_dict['epochs']
min_loss = state_dict['min_loss']
acc = state_dict['acc'] if 'acc' in state_dict else 0.0
print("=" * 40)
print('Resume from epoch: {}, min-loss: {} acc: {}'.format(epochs, min_loss, acc))
print("=" * 40)

# Move the trainer to the selected GPU and switch it to evaluation mode.
trainer.cuda()
trainer.eval()
# Collect the images to evaluate: every entry of the input directory whose file
# name contains 'input'. The listing is sorted so that the processing order and
# the order of lines in the log are the same on every run (os.listdir returns
# entries in arbitrary order).
test_list = [os.path.join(opts.input_dir, name)
             for name in sorted(os.listdir(opts.input_dir))
             if 'input' in name]

# Preprocessing applied to each image before it is fed to the network.
# An earlier cat-vs-dog setup resized to config['new_size'] on both sides and
# normalised with mean/std 0.5 per channel; this version resizes to the
# configured crop size instead.
# NOTE(review): the mean/std values look like the usual ImageNet statistics --
# confirm they match what the checkpoint was trained with.
transform = transforms.Compose([
    transforms.Resize([config['crop_image_height'], config['crop_image_width']]),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])

# Prefix the log name with the input directory's own name so runs over
# different directories do not overwrite each other. normpath strips a trailing
# separator first; otherwise basename('dir/') is '' and the prefix is lost.
opts.log_name = os.path.basename(os.path.normpath(opts.input_dir)) + '-' + opts.log_name
log_pwd = os.path.join(opts.output_dir, opts.log_name)
check_dir(opts.output_dir)

# Per-image correctness (1.0 / 0.0), filled in by the evaluation loop.
accuracy_list = []
# Classify every image in test_list, log each misclassification to the console
# and to the log file, then report the mean accuracy.
with torch.no_grad():
    t_bar = tqdm(test_list)
    t_bar.set_description('Processing')
    with open(log_pwd, 'w') as fid_w:
        for img_pwd in t_bar:
            image = Image.open(img_pwd).convert('RGB')

            # The ground-truth class is the integer before the first '-' in the
            # name of the directory holding the image (e.g. '0-clear' -> 0).
            label = int(os.path.dirname(img_pwd).split(os.sep)[-1].split('-')[0])

            # Add a batch dimension of 1 and move the tensor to the GPU.
            batch = transform(image).unsqueeze(0).cuda()
            pred = trainer.net(batch)

            # NOTE(review): exp() suggests the network emits log-probabilities
            # -- confirm against the model definition.
            ps = torch.exp(pred)
            _, top_class = ps.topk(1, dim=1)
            predicted = top_class.item()

            correct = int(predicted == label)
            accuracy_list.append(float(correct))

            # Only misclassified images are reported.
            if correct < 1:
                line_info = '{} | pred: {}, label: {}'.format(img_pwd, predicted, label)
                print(line_info)
                fid_w.write(line_info + '\n')

        # np.mean([]) yields NaN but also emits a RuntimeWarning; when no image
        # matched, report NaN directly instead.
        mean_acc = np.mean(accuracy_list) if accuracy_list else float('nan')

        # Format the summary once so the console and the log file carry the
        # same timestamp.
        summary = '\n<{}> Test result: accuracy: {}'.format(get_local_time(), mean_acc)
        print(summary)
        fid_w.write(summary + '\n')