-
Notifications
You must be signed in to change notification settings - Fork 27
/
visualize.py
97 lines (82 loc) · 3.28 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import numpy as np
import os.path as op
import torch.nn.functional as F
from datasets import build_dataloader
from utils.checkpoint import Checkpointer
from model import build_model
from utils.metrics import Evaluator
from utils.iotools import load_train_configs
import random
import matplotlib.pyplot as plt
from PIL import Image
from datasets.cuhkpedes import CUHKPEDES
# --- Configuration: reuse the training config, switched to evaluation settings. ---
config_file = '/xxx/configs.yaml'  # placeholder path; point this at the run's saved configs.yaml
args = load_train_configs(config_file)
args.batch_size = 1024  # overrides the configured batch size before the loaders below are built
args.training = False  # NOTE(review): presumably makes build_dataloader return the test loaders — confirm
device = "cuda"
# --- Data and model: build the test loaders, then restore the best checkpoint. ---
test_img_loader, test_txt_loader = build_dataloader(args)
model = build_model(args)
checkpointer = Checkpointer(model)
checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
model.to(device)
# --- Embeddings: encode the test captions (queries) and test images (gallery). ---
evaluator = Evaluator(test_img_loader, test_txt_loader)
qfeats, gfeats, qids, gids = evaluator._compute_embedding(model.eval())
qfeats = F.normalize(qfeats, p=2, dim=1) # text features
gfeats = F.normalize(gfeats, p=2, dim=1) # image features
# Rows are L2-normalised, so this dot product is cosine similarity:
# one row per text query, one column per gallery image.
similarity = qfeats @ gfeats.t()
# accelerate sort with topk: only the 10 best gallery indices per query are needed
_, indices = torch.topk(similarity, k=10, dim=1, largest=True, sorted=True) # q * topk, best match first
# --- Raw test split: captions and image paths used to label the plots. ---
dataset = CUHKPEDES(root='./data')
test_dataset = dataset.test
# NOTE(review): plotting assumes these lists are ordered like the loaders' qids/gids — confirm
img_paths = test_dataset['img_paths']
captions = test_dataset['captions']
gt_img_paths = test_dataset['gt_img_paths']
def get_one_query_caption_and_result_by_id(idx, indices, qids, gids, captions, img_paths, gt_img_paths):
    """Collect everything needed to visualise the retrieval result of one text query.

    Args:
        idx: position of the query in ``qids`` / ``captions`` / ``gt_img_paths``.
        indices: ranked gallery positions per query; row ``idx`` is the query's top-k,
            best match first. May be a tensor (on any device), an ndarray or nested lists.
        qids: identity label of each query.
        gids: identity label of each gallery image; must accept indexing by a list of
            ints (e.g. a CPU tensor or an ndarray).
        captions: caption text of each query.
        img_paths: file path of each gallery image.
        gt_img_paths: ground-truth image path of each query.

    Returns:
        ``(query_id, image_ids, query_caption, image_paths, gt_image_path)`` where
        ``image_ids`` and ``image_paths`` follow the ranking in ``indices[idx]``.
    """
    # Turn the ranked row into plain Python ints once. The lookups below then work
    # regardless of which device ``indices`` lives on relative to ``gids``.
    ranked = [int(j) for j in indices[idx]]
    query_caption = captions[idx]
    query_id = qids[idx]
    image_paths = [img_paths[j] for j in ranked]
    image_ids = gids[ranked]
    gt_image_path = gt_img_paths[idx]
    return query_id, image_ids, query_caption, image_paths, gt_image_path
def plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path, fname=None,
                          border_width=2, thumb_size=(128, 256), figsize=None):
    """Plot a text query's ground-truth image followed by its ranked retrievals.

    The query id, retrieved ids and caption are printed first. The figure is one row:
    panel 1 is ``gt_img_path``; panels 2.. are ``image_paths`` in the given order.
    Each retrieved panel gets a coloured border: lawngreen when its id equals
    ``query_id``, red otherwise.

    Args:
        query_id: identity label of the text query.
        image_ids: identity labels of the retrieved images, aligned with ``image_paths``.
        query_caption: caption text of the query (printed only).
        image_paths: file paths of the retrieved images, in rank order.
        gt_img_path: file path of the ground-truth image shown in the first panel.
        fname: if truthy, the figure is also saved to this path at 300 dpi.
        border_width: line width of the coloured border around each retrieved image.
        thumb_size: ``(width, height)`` every image is resized to before display.
        figsize: optional matplotlib figure size; ``None`` keeps matplotlib's default.
    """
    print(query_id)
    print(image_ids)
    print(query_caption)
    fig = plt.figure(figsize=figsize)
    col = len(image_paths)
    # Panel 1: the ground-truth image, shown without a coloured border.
    plt.subplot(1, col + 1, 1)
    with Image.open(gt_img_path) as img:
        plt.imshow(img.resize(thumb_size))
    plt.xticks([])
    plt.yticks([])
    for i, path in enumerate(image_paths):
        ax = plt.subplot(1, col + 1, i + 2)
        # Border colour marks whether this retrieval shares the query's identity.
        color = 'lawngreen' if image_ids[i] == query_id else 'red'
        for spine in ax.spines.values():
            spine.set_color(color)
            spine.set_linewidth(border_width)
        with Image.open(path) as img:
            plt.imshow(img.resize(thumb_size))
        plt.xticks([])
        plt.yticks([])
    fig.show()
    if fname:
        fig.savefig(fname, dpi=300)
# Visualise one text query. `idx` selects the query by its position in `qids`
# (valid positions: 0 .. len(qids) - 1); change it to inspect a different query.
(query_id, image_ids, query_caption,
 image_paths, gt_img_path) = get_one_query_caption_and_result_by_id(
    idx=0,
    indices=indices,
    qids=qids,
    gids=gids,
    captions=captions,
    img_paths=img_paths,
    gt_img_paths=gt_img_paths,
)
plot_retrieval_images(
    query_id,
    image_ids,
    query_caption,
    image_paths,
    gt_img_path=gt_img_path,
)