Add visualization code
anosorae authored Mar 29, 2024
1 parent 8b5af4d commit c698f85
Showing 1 changed file with 97 additions and 0 deletions.
97 changes: 97 additions & 0 deletions visualize.py
@@ -0,0 +1,97 @@
import os
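# Select the GPU before torch is imported so that CUDA_VISIBLE_DEVICES takes effect.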
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import numpy as np
import os.path as op
import torch.nn.functional as F
from datasets import build_dataloader
from utils.checkpoint import Checkpointer
from model import build_model
from utils.metrics import Evaluator
from utils.iotools import load_train_configs
import random
import matplotlib.pyplot as plt
from PIL import Image
from datasets.cuhkpedes import CUHKPEDES


config_file = '/xxx/configs.yaml'  # placeholder path: point this at the trained model's configs.yaml
args = load_train_configs(config_file)
args.batch_size = 1024
args.training = False
device = "cuda"
test_img_loader, test_txt_loader = build_dataloader(args)
model = build_model(args)
checkpointer = Checkpointer(model)
checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
model.to(device)

evaluator = Evaluator(test_img_loader, test_txt_loader)

qfeats, gfeats, qids, gids = evaluator._compute_embedding(model.eval())
qfeats = F.normalize(qfeats, p=2, dim=1) # text features
gfeats = F.normalize(gfeats, p=2, dim=1) # image features
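# With L2-normalized features, the dot product below gives the cosine similarity
# between every text query and every gallery image.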

similarity = qfeats @ gfeats.t()
# accelerate the sort by keeping only the top-k matches per query
_, indices = torch.topk(similarity, k=10, dim=1, largest=True, sorted=True)  # [num_queries, 10] gallery indices, ranked by similarity

dataset = CUHKPEDES(root='./data')
test_dataset = dataset.test

img_paths = test_dataset['img_paths']
captions = test_dataset['captions']
gt_img_paths = test_dataset['gt_img_paths']
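# The test split provides the gallery image paths, the query captions, and the
# ground-truth image path associated with each caption.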

def get_one_query_caption_and_result_by_id(idx, indices, qids, gids, captions, img_paths, gt_img_paths):
    query_caption = captions[idx]
    query_id = qids[idx]
    image_paths = [img_paths[j] for j in indices[idx]]
    image_ids = gids[indices[idx]]
    gt_image_path = gt_img_paths[idx]
    return query_id, image_ids, query_caption, image_paths, gt_image_path

def plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path, fname=None):
    print(query_id)
    print(image_ids)
    print(query_caption)
    fig = plt.figure()
    col = len(image_paths)

    # plot the ground-truth image in the first column
    plt.subplot(1, col + 1, 1)
    img = Image.open(gt_img_path)
    img = img.resize((128, 256))
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

    # plot the top-k retrieved images
    for i in range(col):
        plt.subplot(1, col + 1, i + 2)
        img = Image.open(image_paths[i])

        bwith = 2  # border line width
        ax = plt.gca()  # current axes, used to color the image border
        # green border: the retrieved image matches the query identity; red border: mismatch
        if image_ids[i] == query_id:
            for spine in ax.spines.values():
                spine.set_color('lawngreen')
                spine.set_linewidth(bwith)
        else:
            for spine in ax.spines.values():
                spine.set_color('red')
                spine.set_linewidth(bwith)

        img = img.resize((128, 256))
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])

    fig.show()
    if fname:
        plt.savefig(fname, dpi=300)
# idx indexes into qids (the list of query ids); valid values range from 0 to len(qids) - 1
query_id, image_ids, query_caption, image_paths, gt_img_path = get_one_query_caption_and_result_by_id(0, indices, qids, gids, captions, img_paths, gt_img_paths)
plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path)
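# To also save the figure, pass an output path via fname (the filename below is illustrative):
# plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path,
#                       fname='retrieval_top10.png')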
