-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
68 lines (52 loc) · 2.24 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import argparse
import heapq
from tqdm import tqdm
import time
# Command-line options for the evaluation script:
#   --config : YAML config providing test_dataset_args / batch_size / model_args
#   --model  : checkpoint file whose 'state_dict' entry is loaded into the model
# NOTE(review): the defaults are absolute paths from the original author's
# machine; pass both flags explicitly when running elsewhere.
# Arguments are parsed at import time, so importing this module reads sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    default="/home/qilang/PythonProjects/ICME/Ppromo/config/smarthome-cs/train.yaml",
)
parser.add_argument(
    "--model",
    type=str,
    default="/home/qilang/PythonProjects/ICME/Ppromo/weights/ppromo_cs.pth",
)
args = parser.parse_args()
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
from torchvision import datasets, transforms
from VideoDataset import Dataset
from net.Ppromo_fmw import ppromo
from net.ctrgcn_att import Model
# from dataset import *
from net.utils.Meter import *
from tensorboardX import SummaryWriter
def val():
    """Evaluate a trained ppromo checkpoint on the configured test split.

    Reads the YAML config at ``args.config`` and uses its
    ``test_dataset_args``, ``batch_size`` and ``model_args`` entries. Loads the
    ``state_dict`` stored in the checkpoint at ``args.model``, then runs the
    model over the whole test loader without gradient tracking.

    Progress is shown with a tqdm bar whose postfix holds the latest batch's
    top-1 / top-5 values. The averaged top-1 / top-5 accuracy is printed at
    the end.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(args.config, encoding="utf-8") as f:
        yaml_args = yaml.load(f, Loader=yaml.FullLoader)

    test_dataset = Dataset(**yaml_args['test_dataset_args'])
    # Evaluation order does not affect the batch-size-weighted average, so keep
    # the run deterministic instead of shuffling.
    val_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=yaml_args['batch_size'],
        shuffle=False,
        num_workers=12,
        pin_memory=True,
    )

    model = ppromo(**yaml_args['model_args']).cuda()
    # Load onto the CPU first. load_state_dict copies the tensors into the
    # model's CUDA parameters, so the device the checkpoint was saved from
    # does not matter.
    checkpoint = torch.load(args.model, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

    print("--------------------------VAL--------------------------------")
    model.eval()
    top1 = AverageMeter()
    top5 = AverageMeter()
    with tqdm(total=len(val_dataloader), desc="Test") as pbar, torch.no_grad():
        for frame_indices, inputs, skl, labels in val_dataloader:
            # Plain tensors replace the deprecated autograd.Variable wrappers.
            # no_grad() already disables gradient tracking.
            # Batches come from pinned memory, so the host->device copies can
            # be asynchronous.
            inputs = inputs.cuda(non_blocking=True)
            frame_indices = frame_indices.cuda(non_blocking=True)
            # Average the skeleton tensor over its last axis, keeping that axis
            # as size 1 (identical to skl.mean(-1).unsqueeze(-1)).
            skl = skl.mean(-1, keepdim=True).cuda()
            labels = labels.cuda(non_blocking=True)

            pred = model(inputs, skl, frame_indices)
            prec1, prec5 = accuracy(pred, labels, topk=(1, 5))

            # Weight each batch's precision by its size so the running
            # average reflects per-sample accuracy.
            batch_size = inputs.size(0)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            pbar.set_postfix({'top-1 && top-5' : '{:.1f} ,{:.1f}'.format(top1.val, top5.val)})
            pbar.update(1)
    print('Top-1: {:.4f}, Top-5: {:.4f}'.format(top1.avg, top5.avg))
if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly,
    # never as a side effect of importing this module.
    val()