-
Notifications
You must be signed in to change notification settings - Fork 71
/
voting_evaluate_cls.py
101 lines (87 loc) · 3.43 KB
/
voting_evaluate_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import os
from torchvision import transforms
from models import RSCNN_SSN_Cls as RSCNN_SSN
from data import ModelNet40Cls
import utils.pytorch_utils as pt_utils
import utils.pointnet2_utils as pointnet2_utils
import data.data_utils as d_utils
import argparse
import random
import yaml
# cuDNN configuration for this evaluation run.
# NOTE(review): benchmark=True lets cuDNN auto-tune algorithm choice, which the
# PyTorch reproducibility notes say can vary between runs even when
# deterministic=True is requested; kept as-is to preserve behavior.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# Seed Python's, NumPy's and torch's (CPU, current CUDA device, all CUDA
# devices) random number generators with one fixed value, in that order.
seed = 123
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                  torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)
del _seed_rng

# Command line: only the path of the YAML config file is taken from argv.
parser = argparse.ArgumentParser(
    description='Relation-Shape CNN Shape Classification Voting Evaluation')
parser.add_argument('--config', type=str, default='cfgs/config_ssn_cls.yaml')

# Number of full evaluation passes over the test set.
NUM_REPEAT = 300
# Number of softmax predictions averaged per batch within one pass.
NUM_VOTE = 10
def main():
    """Run repeated multi-vote evaluation of the RSCNN_SSN classifier.

    Parses ``--config``, copies every key of the YAML file's ``common``
    section onto the argparse namespace, builds the ModelNet40Cls test
    loader (``train=False``) and the model (on CUDA), optionally loads
    weights from ``args.checkpoint``, then performs ``NUM_REPEAT`` voting
    passes over the test set. Prints the accuracy of each pass and, at the
    end, the best accuracy seen.
    """
    args = parser.parse_args()
    _load_config(args)

    test_dataloader = _build_test_loader(args)
    model = _build_model(args)

    # Random scaling augmentation, applied to every vote except the first.
    pointcloud_scale = d_utils.PointcloudScale()
    model.eval()

    global_acc = 0
    for i in range(NUM_REPEAT):
        acc = _evaluate_once(model, test_dataloader, pointcloud_scale, args.num_points)
        if acc > global_acc:
            global_acc = acc
        print('Repeat %3d \t Acc: %0.6f' % (i + 1, acc))
    print('\nBest voting acc: %0.6f' % (global_acc))


def _load_config(args):
    """Copy every key of the config file's ``common`` section onto ``args``."""
    with open(args.config) as f:
        # safe_load: a bare yaml.load() without a Loader is unsafe on
        # untrusted files and raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)


def _build_test_loader(args):
    """Return a non-shuffled DataLoader over the ModelNet40Cls test split."""
    test_transforms = transforms.Compose([
        d_utils.PointcloudToTensor(),
    ])
    test_dataset = ModelNet40Cls(num_points=args.num_points, root=args.data_root,
                                 transforms=test_transforms, train=False)
    return DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        pin_memory=True,
    )


def _build_model(args):
    """Create the classifier on CUDA and load ``args.checkpoint`` when one is set."""
    model = RSCNN_SSN(num_classes=args.num_classes, input_channels=args.input_channels,
                      relation_prior=args.relation_prior, use_xyz=True)
    model.cuda()
    # The original `is not ''` tested identity, not equality. A missing,
    # empty or null checkpoint entry now simply means "no checkpoint".
    checkpoint = getattr(args, 'checkpoint', '')
    if checkpoint:
        model.load_state_dict(torch.load(checkpoint))
        print('Load model successfully: %s' % (checkpoint))
    return model


def _evaluate_once(model, dataloader, pointcloud_scale, num_points, fps_pool_size=1200):
    """Run one voting pass over ``dataloader`` and return the accuracy as a float.

    For each batch, ``fps_pool_size`` candidate indices are chosen by
    furthest point sampling. Each of the ``NUM_VOTE`` votes then draws a
    random subset of ``num_points`` of them (without replacement), gathers
    those points, and adds the model's softmax output. Every vote except
    the first is randomly rescaled first. The class with the highest
    averaged probability is the prediction.

    Raises:
        ValueError: if ``num_points`` exceeds ``fps_pool_size``, since the
            per-vote subset is drawn without replacement from the pool.
    """
    if num_points > fps_pool_size:
        raise ValueError('num_points (%d) must not exceed fps_pool_size (%d)'
                         % (num_points, fps_pool_size))
    preds = []
    labels = []
    # Inference only: no_grad replaces the removed `volatile=True` flag, so
    # no autograd graph is recorded for the forward passes.
    with torch.no_grad():
        for points, target in dataloader:
            points, target = points.cuda(), target.cuda()
            # Furthest point sampling -> (B, fps_pool_size) indices.
            fps_idx = pointnet2_utils.furthest_point_sample(points, fps_pool_size)
            pred = 0
            for v in range(NUM_VOTE):
                subset = np.random.choice(fps_pool_size, num_points, replace=False)
                new_fps_idx = fps_idx[:, subset]
                new_points = pointnet2_utils.gather_operation(
                    points.transpose(1, 2).contiguous(), new_fps_idx
                ).transpose(1, 2).contiguous()
                if v > 0:
                    new_points = pointcloud_scale(new_points)
                pred += F.softmax(model(new_points), dim=1)
            pred /= NUM_VOTE
            _, pred_choice = torch.max(pred, -1)
            preds.append(pred_choice)
            labels.append(target.view(-1))
    preds = torch.cat(preds, 0)
    labels = torch.cat(labels, 0)
    # Convert to Python numbers so the ratio is true float division,
    # independent of tensor integer semantics or Python 2 `/`.
    correct = (preds == labels).sum().item()
    return float(correct) / labels.numel()
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()