# validation.py
# Evaluate a trained WSDHQ model on the query and database splits of a dataset.
# Usage: python validation.py <dataset> <model_weight> <gpu>
import sys
import logging
import warnings
import dataset
from net_val import WSDHQ
from util import set_logger, DataLoader
# Ignore every DeprecationWarning and FutureWarning for the rest of the run
# (warning filters are process-wide). Installed in the same order as before,
# so FutureWarning ends up first in warnings.filters.
for _ignored_category in (DeprecationWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
### Define input arguments
# Usage: python validation.py <dataset> <model_weight> <gpu>
if len(sys.argv) < 4:
    # Fail with a readable message instead of an IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} <dataset: nuswide|flickr> <model_weight> <gpu>")
ds = sys.argv[1]  # dataset key: 'nuswide' or 'flickr'
model_weight = sys.argv[2]  # path to the trained model weights
gpu = sys.argv[3]  # GPU index; becomes the '/gpu:<gpu>' device string in config
# The code length is read from the weight path itself: the text after the last
# '=' in the third '_'-separated field must be an integer.
# NOTE(review): '// 8' presumably means 8 bits per subspace (256 subcenters
# = 2**8 in config) — confirm against the training script.
try:
    subspace_num = int(model_weight.split('_')[2].split('=')[-1]) // 8
except (IndexError, ValueError) as err:
    sys.exit(f"cannot parse code length from model weight path {model_weight!r}: {err}")
if ds not in ['nuswide', 'flickr']:
    # Report the key that is actually accepted ('nuswide'), not its directory name.
    print(f"unknown dataset '{ds}', use 'nuswide' by default.")
    ds = 'nuswide'
# Dataset key -> directory name under ./data/ (image lists, label files, tag embeddings).
ds2path_map = dict(
    nuswide='nus-wide',
    flickr='flickr25k',
)
# Dataset key -> root directory of the image files.
ds2img_data_root_map = dict(
    nuswide='./datasets/nus-wide/',
    flickr='./datasets/flickr25k/',
)
# Every setting the validation run needs, passed to set_logger, the dataset
# importer and the WSDHQ model.
config = dict(
    # -- general --
    dataset=ds,
    train=False,
    device='/gpu:' + gpu,
    batch_size=100,
    output_dim=300,
    topK=5000,  # mAP is reported as mAP@topK
    # -- data files, tag embeddings and model checkpoint --
    img_data_root=ds2img_data_root_map[ds],
    test_img_fpath=f"./data/{ds2path_map[ds]}/test_img.txt",
    test_label_fpath=f"./data/{ds2path_map[ds]}/test_label.txt",
    database_img_fpath=f"./data/{ds2path_map[ds]}/database_img.txt",
    database_label_fpath=f"./data/{ds2path_map[ds]}/database_label.txt",
    final_tag_embs_fpath=f"./data/{ds2path_map[ds]}/tags/FinalTagEmbs.txt",
    # -- image backbone --
    img_model='alexnet',
    model_weights_fpath=model_weight,
    # -- quantization --
    max_iter_update_b=3,
    code_batch_size=50 * 14,
    subspace_num=subspace_num,
    subcenter_num=256,
    # -- evaluation --
    reload_if_exists=True,  # reuse a cached code file for evaluation when one exists
    evaluator_type='np',  # 'np' or 'tf'
    metric_mode='111',  # one flag each for AQD, SQD, feats; '1' on, '0' off
)
set_logger(config)

logging.info("prepare dataset")
qry_dataset, db_dataset = dataset.import_validation(config)

logging.info("prepare data loader")
# Both loaders get the same feature width and the same codeword count
# (subspace_num * subcenter_num); the database loader is built first.
n_codewords = config['subspace_num'] * config['subcenter_num']
db_dataloader = DataLoader(db_dataset, config['output_dim'], n_codewords)
qry_dataloader = DataLoader(qry_dataset, config['output_dim'], n_codewords)

logging.info("prepare model")
model = WSDHQ(config)

logging.info("begin validation")
# Positional evaluation options, read from config in the order validation() takes them.
validation_args = [
    config[option]
    for option in ('topK', 'reload_if_exists', 'evaluator_type', 'metric_mode')
]
model.validation(qry_dataloader, db_dataloader, *validation_args)
logging.info("finish validation")