-
Notifications
You must be signed in to change notification settings - Fork 576
/
Copy pathreid_metric.py
97 lines (82 loc) · 3.18 KB
/
reid_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric

from data.datasets.eval_reid import eval_func

from .re_ranking import re_ranking
class R1_mAP(Metric):
    """Ignite metric scoring query-vs-gallery retrieval with ``eval_func``.

    Features, person ids and camera ids are accumulated batch by batch via
    ``update``. ``compute`` treats the first ``num_query`` accumulated samples
    as the query set and the remaining ones as the gallery, builds the matrix
    of squared Euclidean distances between them, and returns whatever
    ``eval_func`` returns as ``(cmc, mAP)``.

    Args:
        num_query: Number of leading samples (in accumulation order) forming
            the query set. The evaluation loader must therefore yield all
            query samples before any gallery sample.
        max_rank: Rank cut-off forwarded to ``eval_func``.
        feat_norm: ``'yes'`` to L2-normalise every feature row before the
            distances are computed; any other value leaves features as-is.
    """

    def __init__(self, num_query, max_rank=50, feat_norm='yes'):
        super(R1_mAP, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank
        self.feat_norm = feat_norm

    def reset(self):
        """Discard every accumulated feature and label."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        """Accumulate one batch.

        Args:
            output: ``(feat, pid, camid)`` tuple. ``feat`` is a tensor whose
                rows are samples (batches are concatenated along dim 0 in
                ``compute``); ``pid`` and ``camid`` hold one label per sample
                and must be convertible with ``np.asarray``.
        """
        feat, pid, camid = output
        self.feats.append(feat)
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def _split_query_gallery(self, feats):
        """Split features and labels at ``num_query`` into (query, gallery) triples."""
        nq = self.num_query
        query = (feats[:nq],
                 np.asarray(self.pids[:nq]),
                 np.asarray(self.camids[:nq]))
        gallery = (feats[nq:],
                   np.asarray(self.pids[nq:]),
                   np.asarray(self.camids[nq:]))
        return query, gallery

    @staticmethod
    def _squared_euclidean(qf, gf):
        """Return the (m, n) matrix of squared L2 distances between rows of qf and gf."""
        # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g ; (m, 1) + (1, n) broadcasts to (m, n).
        q_sq = torch.pow(qf, 2).sum(dim=1, keepdim=True)
        g_sq = torch.pow(gf, 2).sum(dim=1, keepdim=True).t()
        distmat = q_sq + g_sq
        # Keyword beta/alpha: the positional addmm_(beta, alpha, mat1, mat2)
        # overload is deprecated and rejected by recent PyTorch releases.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        return distmat

    def compute(self):
        """Return ``(cmc, mAP)`` exactly as produced by ``eval_func``.

        Raises:
            NotComputableError: if no batch was accumulated since the last reset.
        """
        if not self.feats:
            raise NotComputableError(
                'R1_mAP must have at least one example before it can be computed.')
        feats = torch.cat(self.feats, dim=0)
        if self.feat_norm == 'yes':
            print("The test feature is normalized")
            feats = torch.nn.functional.normalize(feats, dim=1, p=2)
        (qf, q_pids, q_camids), (gf, g_pids, g_camids) = self._split_query_gallery(feats)
        distmat = self._squared_euclidean(qf, gf).cpu().numpy()
        # max_rank was previously stored but never forwarded, silently ignoring it.
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids,
                             max_rank=self.max_rank)
        return cmc, mAP
class R1_mAP_reranking(Metric):
    """Ignite metric scoring query-vs-gallery retrieval after ``re_ranking``.

    Accumulation works batch by batch via ``update``. ``compute`` treats the
    first ``num_query`` accumulated samples as queries and the rest as the
    gallery, obtains a distance matrix from ``re_ranking`` and returns
    whatever ``eval_func`` returns as ``(cmc, mAP)``.

    Args:
        num_query: Number of leading samples (in accumulation order) forming
            the query set. The evaluation loader must therefore yield all
            query samples before any gallery sample.
        max_rank: Rank cut-off forwarded to ``eval_func``.
        feat_norm: ``'yes'`` to L2-normalise every feature row before
            re-ranking; any other value leaves features as-is.
        k1, k2, lambda_value: Forwarded unchanged to ``re_ranking``. The
            defaults are the values this metric previously hard-coded.
    """

    def __init__(self, num_query, max_rank=50, feat_norm='yes',
                 k1=20, k2=6, lambda_value=0.3):
        super(R1_mAP_reranking, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank
        self.feat_norm = feat_norm
        self.k1 = k1
        self.k2 = k2
        self.lambda_value = lambda_value

    def reset(self):
        """Discard every accumulated feature and label."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        """Accumulate one batch.

        Args:
            output: ``(feat, pid, camid)`` tuple. ``feat`` is a tensor whose
                rows are samples (batches are concatenated along dim 0 in
                ``compute``); ``pid`` and ``camid`` hold one label per sample
                and must be convertible with ``np.asarray``.
        """
        feat, pid, camid = output
        self.feats.append(feat)
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def _split_query_gallery(self, feats):
        """Split features and labels at ``num_query`` into (query, gallery) triples."""
        nq = self.num_query
        query = (feats[:nq],
                 np.asarray(self.pids[:nq]),
                 np.asarray(self.camids[:nq]))
        gallery = (feats[nq:],
                   np.asarray(self.pids[nq:]),
                   np.asarray(self.camids[nq:]))
        return query, gallery

    def compute(self):
        """Return ``(cmc, mAP)`` exactly as produced by ``eval_func``.

        Raises:
            NotComputableError: if no batch was accumulated since the last reset.
        """
        if not self.feats:
            raise NotComputableError(
                'R1_mAP_reranking must have at least one example before it can be computed.')
        feats = torch.cat(self.feats, dim=0)
        if self.feat_norm == 'yes':
            print("The test feature is normalized")
            feats = torch.nn.functional.normalize(feats, dim=1, p=2)
        (qf, q_pids, q_camids), (gf, g_pids, g_camids) = self._split_query_gallery(feats)
        print("Enter reranking")
        distmat = re_ranking(qf, gf, k1=self.k1, k2=self.k2,
                             lambda_value=self.lambda_value)
        # max_rank was previously stored but never forwarded, silently ignoring it.
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids,
                             max_rank=self.max_rank)
        return cmc, mAP