-
Notifications
You must be signed in to change notification settings - Fork 1
/
eval_ensemble.py
103 lines (84 loc) · 3.47 KB
/
eval_ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
import time
import os
from six.moves import cPickle
import opts
import models
from dataloader import *
from dataloaderraw import *
import eval_utils
import argparse
import misc.utils as utils
import torch
# Input arguments and options
parser = argparse.ArgumentParser()
# Input paths
parser.add_argument('--ids', nargs='+', required=True,
                    help='id of the models to ensemble')
# Optional per-model weights. They are paired positionally with --ids further
# down in this script, so they must be given in the same order. Parsing them
# as float here makes argparse reject a malformed value up front.
parser.add_argument('--weights', nargs='+', type=float, required=False, default=None,
                    help='weight of each model in the ensemble, in the same order as --ids')
# parser.add_argument('--models', nargs='+', required=True
#                 help='path to model to evaluate')
# parser.add_argument('--infos_paths', nargs='+', required=True, help='path to infos to evaluate')
# Let opts register its evaluation options on the same parser, then parse argv.
opts.add_eval_options(parser)
opt = parser.parse_args()
# For every requested id, load its infos file and remember where its weights
# live. Both lists are index-aligned with opt.ids.
model_infos = []
model_paths = []
for model_id in opt.ids:
    # An id may be written as "<id>-<suffix>"; the suffix selects
    # infos_<id>-<suffix>.pkl / model-<suffix>.pth inside log_<id>/.
    # Split on the LAST hyphen only, so an id that itself contains hyphens
    # no longer fails with an unpacking error when a suffix is supplied.
    # Limitation (unchanged): a hyphenated id given without a suffix is still
    # read as "<id>-<suffix>".
    if '-' in model_id:
        model_id, suffix = model_id.rsplit('-', 1)
        suffix = '-' + suffix
    else:
        suffix = ''
    # NOTE(review): pickle_load presumably unpickles the file - only point
    # --ids at log directories you trust.
    # The context manager closes the handle as soon as the infos are loaded.
    with open('log_%s/infos_%s%s.pkl' % (model_id, model_id, suffix), 'rb') as infos_file:
        model_infos.append(utils.pickle_load(infos_file))
    model_paths.append('log_%s/model%s.pth' % (model_id, suffix))
# Load one infos: the first model's infos is the reference for shared settings.
infos = model_infos[0]

# override and collect parameters
# For these keys a truthy command-line value wins; otherwise fall back to the
# value stored with the first model (or '' when it has none).
replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id']
for k in replace:
    setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))

# copy over options from model: only keys the command line did not define.
vars(opt).update({name: value
                  for name, value in vars(infos['opt']).items()
                  if name not in vars(opt)})

# use_box is switched on if any member of the ensemble uses it.
# The comprehension variable is deliberately NOT called `infos`: under Python 2
# a list-comprehension variable leaks into the enclosing scope and would
# replace the reference `infos` selected above before `vocab` is read.
opt.use_box = max(getattr(info['opt'], 'use_box', 0) for info in model_infos)

# Every member must agree on the feature-normalisation flags. The original
# checks compared max(...) with max(...) and so could never fail; comparing
# max with min actually detects a mismatch.
for flag in ('norm_att_feat', 'norm_box_feat'):
    flag_values = [getattr(info['opt'], flag, 0) for info in model_infos]
    assert max(flag_values) == min(flag_values), 'Not support different %s' % flag

vocab = infos['vocab']  # ix -> word mapping
# Setup the model
from models.AttEnsemble import AttEnsemble

# Build each ensemble member from its own saved options, then restore its
# weights from the matching checkpoint path.
_models = []
for member_info, checkpoint_path in zip(model_infos, model_paths):
    member_opt = member_info['opt']
    member_opt.start_from = None
    # Every member is given the vocab taken from the first model's infos.
    member_opt.vocab = vocab
    member = models.setup(member_opt)
    member.load_state_dict(torch.load(checkpoint_path))
    _models.append(member)

# Weights are optional; when supplied, make sure they are floats before they
# reach the ensemble.
if opt.weights is not None:
    opt.weights = list(map(float, opt.weights))

model = AttEnsemble(_models, weights=opt.weights)
model.seq_length = opt.max_length
model.cuda()
model.eval()
crit = utils.LanguageModelCriterion()

# Create the Data Loader instance
# With no image folder given, use the regular DataLoader driven by opt;
# otherwise read raw images from the folder through DataLoaderRaw.
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    loader = DataLoaderRaw({'folder_path': opt.image_folder,
                            'coco_json': opt.coco_json,
                            'batch_size': opt.batch_size,
                            'cnn_model': opt.cnn_model})
# When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
# So make sure to use the vocab in infos file.
loader.ix_to_word = infos['vocab']

# Build the run id from the member ids, each followed by its weight.
# --weights is optional and defaults to None; zipping over None raised a
# TypeError, so without weights the id is just the member ids joined by '+'.
if opt.weights is None:
    opt.id = '+'.join(opt.ids)
else:
    opt.id = '+'.join([model_id + str(weight)
                       for model_id, weight in zip(opt.ids, opt.weights)])
# Run the evaluation; the parsed options are handed to eval_split as a plain dict.
loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
                                                            vars(opt))

print('loss: ', loss)
# lang_stats is only printed when eval_split returned something truthy.
if lang_stats:
    print(lang_stats)

if opt.dump_json == 1:
    # dump the json
    # Use a context manager so the file is flushed and closed deterministically
    # instead of relying on garbage collection of an anonymous handle.
    with open('vis/vis.json', 'w') as vis_file:
        json.dump(split_predictions, vis_file)