-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest.py
46 lines (38 loc) · 2.17 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import torch
import os
from torch.utils.tensorboard import SummaryWriter
from models import Model
from loss import loss_kd
from dataloader import get_test_loader
from utils import Params, AverageMeter
from tqdm import tqdm
# Fix the global PyTorch RNG seed so any randomness during evaluation is repeatable.
torch.manual_seed(0)


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` is a known pitfall: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so the flag could never be disabled.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line options for the evaluation run.
parser = argparse.ArgumentParser()
parser.add_argument('--image_size', type=int, default=224, help='the height / width of the input image to network')
parser.add_argument('--params_dir', type=str, default="params", help='the directory of hyper parameters')
parser.add_argument('-m', '--model_name', type=str, default='base', help='the name of backbone network')
parser.add_argument('--log_path', type=str, default='logs', help="directory to save train log")
parser.add_argument('--epoch', type=int, default=0, help='value of current epoch')
parser.add_argument('--num_epoch', type=int, default=89, help='the number of epoch in train')
parser.add_argument('--decay_epoch', type=int, default=30, help='the number of decay epoch in train')
parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help="path to saved models (to continue training)")
parser.add_argument('--num_classes', type=int, default=100, help='the number of classes')
parser.add_argument('--dataset', type=str, default='cifar100', help='the name of dataset')
# Parsed with _str2bool so that "--is_distill False" actually yields False.
parser.add_argument('--is_distill', type=_str2bool, default=True)
# Parsed at import time (module-level) so `args` is available as a global.
args = parser.parse_args()
if __name__ == '__main__':
    # Load the per-model hyper-parameter file, e.g. params/<model_name>.json.
    params = Params(os.path.join(args.params_dir, f'{args.model_name}.json'))
    net = Model(args.num_classes, params)
    # NOTE(review): the checkpoint file name is hard-coded to '80.pth' and does not
    # follow --epoch; confirm this is the intended checkpoint to evaluate.
    net.load_params(os.path.join(args.checkpoint_dir, args.dataset, params.model_name, '80.pth'))
    test_loader = get_test_loader(args.image_size, params.batch_size, args.dataset)

    # Accumulate exact counts rather than averaging per-batch accuracies:
    # a per-batch mean over-weights a smaller final batch and biases the result.
    num_correct = 0
    num_seen = 0
    with torch.no_grad():  # inference only; no autograd graph needed
        for images, targets in tqdm(test_loader, desc=f'{params.model_name} Testing...'):
            images: torch.Tensor = images.to(net.device)
            targets: torch.Tensor = targets.to(net.device)
            # presumably returns predicted class indices aligned with `targets` — verify in models.py
            preds: torch.Tensor = net.predict_image(images)
            num_correct += (preds == targets).sum().item()
            num_seen += images.shape[0]

    # Guard against an empty loader instead of dividing by zero.
    accuracy = num_correct / num_seen if num_seen else 0.0
    print(f'{accuracy * 100:.4f}%')