-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
71 lines (49 loc) · 2.03 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
from models import MyDataset
from models import ccnet
import os
# Restrict this process to the first GPU. This has to happen before torch
# first queries CUDA (the is_available() call just below).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('\ndevice-> ', device, '\n\n')

# List file describing the test split; its line format is defined by
# MyDataset (not visible here).
test_set = './data/test_Tongji.txt'
testset = MyDataset(txt=test_set, transforms=None, train=False)
batch_size = 1024
# shuffle=False keeps features and labels in the dataset's original order.
data_loader_test = DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)

# num_classes must match the dataset the checkpoint was trained on:
# IITD: 460  KTU: 145  Tongji: 600  REST: 358  XJTU: 200
net = ccnet(num_classes=600, weight=0.8)
checkpoint_path = '/media/Storage4/mengqingguo/code/CCNet/output/checkpoint/net_params.pth'
# map_location remaps tensors saved on a GPU onto `device`, so the
# checkpoint also loads when the CPU fallback above was taken.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net.to(device)
# Switch to inference mode (affects layers such as dropout / batch-norm).
net.eval()
# Feature extraction: run every test batch through the network and collect
# the feature codes together with their labels.
# Per-batch results are gathered in lists and concatenated once at the end;
# concatenating onto the growing array inside the loop would re-copy all
# previously extracted features on every batch (quadratic in batch count).
feat_chunks = []
label_chunks = []
with torch.no_grad():  # inference only: no autograd graph is built
    for data, target in tqdm(data_loader_test):
        # Only the first element of `data` is fed to the network; presumably
        # the dataset yields more than one tensor per sample — TODO confirm
        # against MyDataset.
        data = data[0].to(device)
        codes = net.getFeatureCode(data)
        feat_chunks.append(codes.cpu().numpy())
        label_chunks.append(target.cpu().numpy())

if not feat_chunks:
    raise RuntimeError('test set is empty: no features were extracted.')
featDB_test = np.concatenate(feat_chunks, axis=0)
iddb_test = np.concatenate(label_chunks)

print('completed feature extraction for test set.')
print('(number of samples, feature vector dimensionality): ', featDB_test.shape)
print('\n')
def _angular_distance(feat_a, feat_b):
    """Return the angular distance between two feature vectors, scaled to [0, 1].

    The cosine similarity is converted to an angle with arccos (radians,
    in [0, pi]) and divided by pi: 0 means the vectors point the same way,
    1 means they point in opposite directions.

    Both vectors are L2-normalised here, so the cosine is exact even if the
    features are not already unit length. For unit-length features this
    equals the plain dot product.
    """
    cos_sim = np.dot(feat_a, feat_b) / (np.linalg.norm(feat_a) * np.linalg.norm(feat_b))
    # Clip keeps arccos defined when rounding pushes the cosine just past +/-1.
    return np.arccos(np.clip(cos_sim, -1, 1)) / np.pi


feat1 = featDB_test[0]
feat2 = featDB_test[1]
feat3 = featDB_test[-1]

# feature matching: feat1 vs feat2
dis = _angular_distance(feat1, feat2)
print('matching distance, label1 vs label2: \t%.2f, %d vs %d' % (dis, iddb_test[0], iddb_test[1]))

# feature matching: feat1 vs feat3
dis = _angular_distance(feat1, feat3)
print('matching distance, label1 vs label3: \t%.2f, %d vs %d' % (dis, iddb_test[0], iddb_test[-1]))