forked from mangye16/ReID-Survey
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from beesk135/cosface
Cosface
- Loading branch information
Showing
17 changed files
with
814 additions
and
99 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"python.linting.pylintEnabled": true, | ||
"python.linting.enabled": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch | ||
|
||
from .eval_reid import eval_func | ||
|
||
def euclidean_dist(x, y): | ||
m, n = x.size(0), y.size(0) | ||
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) | ||
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() | ||
dist = xx + yy | ||
dist.addmm_(1, -2, x, y.t()) | ||
dist = dist.clamp(min=1e-12).sqrt() | ||
return dist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
|
||
def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=200): | ||
"""Evaluation with market1501 metric | ||
Key: for each query identity, its gallery images from the same camera view are discarded. | ||
""" | ||
num_q, num_g = distmat.shape | ||
if num_g < max_rank: | ||
max_rank = num_g | ||
print("Note: number of gallery samples is quite small, got {}".format(num_g)) | ||
indices = np.argsort(distmat, axis=1) | ||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) | ||
|
||
# compute cmc curve for each query | ||
all_cmc = [] | ||
all_AP = [] | ||
num_valid_q = 0. # number of valid query | ||
for q_idx in range(num_q): | ||
# get query pid and camid | ||
q_pid = q_pids[q_idx] | ||
q_camid = q_camids[q_idx] | ||
|
||
# remove gallery samples that have the same pid and camid with query | ||
order = indices[q_idx] | ||
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) | ||
keep = np.invert(remove) | ||
|
||
# compute cmc curve | ||
# binary vector, positions with value 1 are correct matches | ||
orig_cmc = matches[q_idx][keep] | ||
if not np.any(orig_cmc): | ||
# this condition is true when query identity does not appear in gallery | ||
# [update:20191029] divide by query | ||
all_AP.append(0) | ||
continue | ||
|
||
cmc = orig_cmc.cumsum() | ||
cmc[cmc > 1] = 1 | ||
|
||
all_cmc.append(cmc[:max_rank]) | ||
num_valid_q += 1. | ||
|
||
# compute average precision | ||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision | ||
# [update:20191029] support for map@max_rank | ||
orig_cmc = orig_cmc[:max_rank] | ||
if not np.any(orig_cmc): | ||
all_AP.append(0) | ||
continue | ||
num_rel = orig_cmc.sum() | ||
tmp_cmc = orig_cmc.cumsum() | ||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] | ||
tmp_cmc = np.asarray(tmp_cmc) * orig_cmc | ||
AP = tmp_cmc.sum() / num_rel | ||
all_AP.append(AP) | ||
|
||
# assert num_valid_q > 0, "Error: all query identities do not appear in gallery" | ||
|
||
all_cmc = np.asarray(all_cmc).astype(np.float32) | ||
# [update:20191029] divide by query | ||
all_cmc = all_cmc.sum(0) / num_q | ||
|
||
mAP = np.mean(all_AP) | ||
|
||
return all_cmc, mAP,all_AP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.