-
Notifications
You must be signed in to change notification settings - Fork 136
/
test.py
78 lines (62 loc) · 3.06 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
import evaluation
from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
import torch
from tqdm import tqdm
import argparse
import pickle
import numpy as np
# Seed the three random number generators this script imports (Python's
# `random`, NumPy and PyTorch) with one shared, fixed value at import time --
# presumably so that repeated evaluation runs are comparable.
_SEED = 1234
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)
def predict_captions(model, dataloader, text_field, device=None):
    """Generate captions for every batch with beam search and score them.

    Args:
        model: Captioning model. Must provide ``eval()`` and
            ``beam_search()``; when ``device`` is not given it must also
            provide ``parameters()`` (as a ``torch.nn.Module`` does).
        dataloader: Sized iterable yielding ``(images, caps_gt)`` batches.
        text_field: Text pipeline providing ``vocab.stoi`` and ``decode()``.
        device: Device the image batches are moved to before decoding. When
            ``None`` (the default) the device of the model's first parameter
            is used, so the function no longer depends on a module-level
            ``device`` global that only exists when the file runs as a script.

    Returns:
        The first element of the pair returned by
        ``evaluation.compute_scores`` for the tokenized references and the
        tokenized generated captions.

    Raises:
        ValueError: If ``device`` is ``None`` and the model has no parameters
            to infer a device from.
    """
    import itertools

    if device is None:
        first_param = next(iter(model.parameters()), None)
        if first_param is None:
            raise ValueError(
                'Cannot infer a device from a model without parameters; '
                'pass `device` explicitly.'
            )
        device = first_param.device

    model.eval()
    gen = {}
    gts = {}
    with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
        for it, (images, caps_gt) in enumerate(dataloader):
            images = images.to(device)
            with torch.no_grad():
                # NOTE(review): 20 and 5 are presumably the maximum caption
                # length and the beam size -- confirm against beam_search().
                out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                # Collapse runs of identical consecutive tokens into one.
                gen_i = ' '.join(k for k, _ in itertools.groupby(gen_i))
                # Key each sample by "<batch index>_<index within batch>".
                key = '%d_%d' % (it, i)
                gen[key] = [gen_i.strip()]
                gts[key] = gts_i
            pbar.update()

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)

    return scores
if __name__ == '__main__':
    # Kept as a module-level name: predict_captions() may read it as a global.
    device = torch.device('cuda')

    parser = argparse.ArgumentParser(description='Meshed-Memory Transformer')
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--workers', type=int, default=0)
    parser.add_argument('--features_path', type=str)
    parser.add_argument('--annotation_folder', type=str)
    args = parser.parse_args()

    print('Meshed-Memory Transformer Evaluation')

    # Pipeline for image regions
    image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False)

    # Pipeline for text
    text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                           remove_punctuation=True, nopoints=False)

    # Create the dataset; the annotation folder is passed for both of the
    # last two positional arguments. Only the third split is evaluated.
    dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
    _, _, test_dataset = dataset.splits

    # Load the vocabulary from the working directory. The context manager
    # closes the file handle, which the bare open() call used to leak.
    # NOTE(review): unpickling can execute arbitrary code -- only use a
    # vocab.pkl obtained from a trusted source.
    with open('vocab.pkl', 'rb') as vocab_file:
        text_field.vocab = pickle.load(vocab_file)

    # Model and dataloaders
    # NOTE(review): the positional numbers (3, 0, 54, 3) and m=40 are
    # architecture hyper-parameters whose meaning is defined in
    # models.transformer -- confirm there before changing them.
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': 40})
    decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
    model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

    # Restore the weights stored under the 'state_dict' key of the checkpoint.
    data = torch.load('meshed_memory_transformer.pth')
    model.load_state_dict(data['state_dict'])

    # Captions are kept raw (RawField) and paired with each image's features.
    dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
    dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)

    scores = predict_captions(model, dict_dataloader_test, text_field)
    print(scores)