-
Notifications
You must be signed in to change notification settings - Fork 6
/
Retrieval.py
72 lines (54 loc) · 2.72 KB
/
Retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from Dataloader import *
def Evaluate_mAP(device, gallery_codes, query_codes, gallery_labels, query_labels, Top_N=None):
num_query = query_labels.shape[0]
mean_AP = 0.0
for i in range(num_query):
retrieval = (query_labels[i, :] @ gallery_labels.t() > 0).float()
hamming_dist = (gallery_codes.shape[1] - query_codes[i, :] @ gallery_codes.t())
retrieval = retrieval[T.argsort(hamming_dist)][:Top_N]
retrieval_cnt = retrieval.sum().int().item()
if retrieval_cnt == 0:
continue
score = T.linspace(1, retrieval_cnt, retrieval_cnt).to(device)
index = (T.nonzero(retrieval == 1, as_tuple=False).squeeze() + 1.0).float().to(device)
mean_AP += (score / index).mean()
mean_AP = mean_AP / num_query
return mean_AP
def DoRetrieval(device, net, Img_dir, Gallery_dir, Query_dir, NB_CLS, Top_N, args):
Gallery_set = Loader(Img_dir, Gallery_dir, NB_CLS)
Gallery_loader = T.utils.data.DataLoader(Gallery_set, batch_size=args.batch_size, num_workers=args.num_workers)
Query_set = Loader(Img_dir, Query_dir, NB_CLS)
Query_loader = T.utils.data.DataLoader(Query_set, batch_size=args.batch_size, num_workers=args.num_workers)
Crop_Normalize = T.nn.Sequential(
Kg.CenterCrop(224),
Kg.Normalize(mean=T.as_tensor([0.485, 0.456, 0.406]), std=T.as_tensor([0.229, 0.224, 0.225]))
)
with T.no_grad():
for i, data in enumerate(Gallery_loader, 0):
gallery_x_batch, gallery_y_batch = data[0].to(device), data[1].to(device)
gallery_x_batch = Crop_Normalize(gallery_x_batch)
if gallery_y_batch.dim() == 1:
gallery_y_batch = T.eye(NB_CLS, device=device)[gallery_y_batch]
outputs = net(gallery_x_batch)
if i == 0:
gallery_c = outputs
gallery_y = gallery_y_batch
else:
gallery_c = T.cat([gallery_c, outputs], 0)
gallery_y = T.cat([gallery_y, gallery_y_batch], 0)
for i, data in enumerate(Query_loader, 0):
query_x_batch, query_y_batch = data[0].to(device), data[1].to(device)
query_x_batch = Crop_Normalize(query_x_batch)
if query_y_batch.dim() == 1:
query_y_batch = T.eye(NB_CLS, device=device)[query_y_batch]
outputs = net(query_x_batch)
if i == 0:
query_c = outputs
query_y = query_y_batch
else:
query_c = T.cat([query_c, outputs], 0)
query_y = T.cat([query_y, query_y_batch], 0)
gallery_c = T.sign(gallery_c)
query_c = T.sign(query_c)
mAP = Evaluate_mAP(device, gallery_c, query_c, gallery_y, query_y, Top_N)
return mAP