import os
import sys
import json
import torch
import logging
import argparse

import numpy as np

import evaluation
from model import get_model
from validate import norm_score, cal_perf
import util.tag_data_provider_img as data
import util.metrics as metrics
from basic.util import read_dict, log_config
from basic.common import makedirsforfile, checkToSkip
from test_base import parse_args


def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    # Load the trained checkpoint together with the options it was trained with.
    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
          .format(resume, start_epoch, best_rsum))
    options = checkpoint['opt']
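
    # NOTE (suggestion, not part of the original script): on a CPU-only
    # machine the checkpoint can be loaded with an explicit map_location,
    # e.g. checkpoint = torch.load(resume, map_location='cpu').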

    # Collection setting: swap the training collection for the test collection
    # to derive the output directory.
    testCollection = opt.testCollection
    collections_pathname = options.collections_pathname
    collections_pathname['test'] = testCollection

    trainCollection = options.trainCollection
    output_dir = resume.replace(trainCollection, testCollection)
    if 'checkpoints' in output_dir:
        output_dir = output_dir.replace('/checkpoints/', '/results/')
    else:
        output_dir = output_dir.replace('/%s/' % options.cv_name,
                                        '/results/%s/%s/' % (options.cv_name, trainCollection))

    pred_error_matrix_file = os.path.join(output_dir, 'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    log_config(output_dir)
    logging.info(json.dumps(vars(opt), indent=2))

    # Prepare the caption files: the original-language captions and their
    # machine-translated (Google) counterparts.
    if 'de' in options.data_type and options.task == 1:
        task = '_task1'
    else:
        task = ''
    # e.g. data_type '..._en2de' -> tmp ['en', 'de'] -> lang_type 'de2en'
    tmp = options.data_type.split('_')[-1].split('2')
    lang_type = tmp[-1] + '2' + tmp[0]
    test_cap = os.path.join(rootpath, collections_pathname['test'], 'TextData',
                            '%s%s_%s_2016%s.caption.txt' % (testCollection, opt.split, tmp[-1], task))
    test_cap_trans = os.path.join(rootpath, collections_pathname['test'], 'TextData',
                                  '%s%s_google_%s_2016%s.caption.txt' % (testCollection, opt.split, lang_type, task))
    caption_files = {'test': test_cap}
    caption_files_trans = {'test': test_cap_trans}

    if options.img_encoder != 'clip':
        # Use pre-extracted image features; presumably one feature row per
        # test image, aligned with the image-id list loaded below (the assert
        # only checks the feature dimensionality).
        visual_feature_name = {'test': 'test_2016_flickr-resnet152-avgpool.npy'}
        visual_feat_path = os.path.join(rootpath, collections_pathname['test'], 'FeatureData',
                                        options.visual_feature, visual_feature_name['test'])
        visual_feats = {'test': np.load(visual_feat_path, encoding="latin1")}
        assert options.visual_feat_dim == visual_feats['test'].shape[-1]
    else:
        # With a CLIP image encoder, features are computed on the fly from raw
        # images, so no pre-extracted feature file is loaded here.
        visual_feats = {'test': 'test'}

    # Construct the model and restore the trained weights.
    model = get_model(options.model)(options)
    model.parallel()
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    # Set up the data loaders.
    if options.collection == 'multi30k':
        image_id_name = {'test': 'test_id_2016.txt'}
    elif options.collection == 'mscoco':
        image_id_name = {'test': f'{opt.data_type.split("2")[-1]}_test_id.txt'}
    image_id_file = os.path.join(rootpath, collections_pathname['test'], 'FeatureData',
                                 options.visual_feature, image_id_name['test'])
    test_image_ids_list = []
    with open(image_id_file) as f:
        for line in f:
            test_image_ids_list.append(line.strip())

    vid_data_loader = data.get_vis_data_loader(options, visual_feats['test'], options.img_path,
                                               opt.batch_size, opt.workers, image_ids=test_image_ids_list)
    text_data_loader = data.get_txt_data_loader(options, caption_files['test'], caption_files_trans['test'],
                                                opt.batch_size, opt.workers, is_test=True)

    # Encode images and captions into the shared embedding space.
    video_embs, video_ids = evaluation.encode_text_or_vid(model.embed_vis, vid_data_loader)
    cap_embs, cap_trans_embs, caption_ids = evaluation.encode_text_hybrid(model.embed_txt, text_data_loader)

    v2t_gt, t2v_gt = metrics.get_gt(video_ids, caption_ids)
    logging.info("write into: %s" % output_dir)

    # Compute error matrices for the original and translated captions, then
    # fuse the normalized scores with weight w (late fusion).
    t2v_all_errors_1 = evaluation.cal_error(video_embs, cap_embs, options.measure)
    t2v_all_errors_2 = evaluation.cal_error(video_embs, cap_trans_embs, options.measure)
    w = 0.5
    print('fusion weight w = %s' % w)
    t2v_all_errors_1 = norm_score(t2v_all_errors_1)
    t2v_all_errors_2 = norm_score(t2v_all_errors_2)
    t2v_tag_all_errors = w * t2v_all_errors_1 + (1 - w) * t2v_all_errors_2
    cal_perf(t2v_tag_all_errors, v2t_gt, t2v_gt)
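
    # NOTE (suggestion, not part of the original script): w is fixed at 0.5.
    # With a validation split available, the fusion weight could be swept
    # instead, e.g.:
    #
    #     for w in (0.1, 0.3, 0.5, 0.7, 0.9):
    #         fused = w * t2v_all_errors_1 + (1 - w) * t2v_all_errors_2
    #         cal_perf(fused, v2t_gt, t2v_gt)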

    # Persist the fused error matrix together with the image and caption ids.
    torch.save({'errors': t2v_tag_all_errors, 'videos': video_ids, 'captions': caption_ids},
               pred_error_matrix_file)
    logging.info("write into: %s" % pred_error_matrix_file)


if __name__ == '__main__':
    main()
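
# Example invocation (a sketch only: the flag names below are assumptions
# inferred from the options accessed above; test_base.parse_args defines the
# authoritative argument list):
#
#   python tester_img.py --rootpath $ROOTPATH --testCollection multi30k \
#       --logger_name runs/multi30k_train --checkpoint_name model_best.pth.tar \
#       --batch_size 128 --workers 4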