forked from parksunwoo/show_attend_and_tell_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepro.py
114 lines (98 loc) · 3.63 KB
/
prepro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import pickle
from collections import Counter
import argparse
import nltk
from PIL import Image
from pycocotools.coco import COCO
class Vocabulary(object):
    """Two-way lookup table between words and consecutive integer ids."""

    def __init__(self):
        # word -> id, id -> word, and the id the next new word will receive.
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Store `word` under the next free id; known words are left as is."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of `word`, or the id of '<unk>' if it is unknown."""
        key = word if word in self.word2idx else '<unk>'
        return self.word2idx[key]

    def __len__(self):
        """Return the number of stored words."""
        return len(self.word2idx)
def build_vocab(json, threshold):
    """Count caption tokens and collect the frequent ones in a Vocabulary.

    Args:
        json: annotation file path handed to `COCO`.
        threshold: minimum number of occurrences a token needs to be kept.

    Returns:
        A `Vocabulary` holding the four special tokens followed by every
        token counted at least `threshold` times.
    """
    coco = COCO(json)
    ann_ids = coco.anns.keys()
    total = len(ann_ids)

    # Lower-case and tokenize every caption, accumulating token frequencies.
    token_counts = Counter()
    for step, ann_id in enumerate(ann_ids):
        text = str(coco.anns[ann_id]['caption']).lower()
        token_counts.update(nltk.tokenize.word_tokenize(text))
        if step % 1000 == 0:
            print("[%d/%d] Tokenized the captions." %(step, total))

    # Special tokens go in first so they receive the lowest ids.
    vocab = Vocabulary()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)

    # Tokens seen fewer than `threshold` times are dropped.
    for token, count in token_counts.items():
        if count >= threshold:
            vocab.add_word(token)
    return vocab
def resize_image(image, size=224, resample=None):
    """Center-crop an image to a square and resize it to `size` x `size`.

    Args:
        image: PIL image (any object exposing `size`, `crop` and `resize`).
        size: side length in pixels of the square output (default 224).
        resample: resampling filter passed to `image.resize`. When None the
            Lanczos filter is used.

    Returns:
        The cropped and resized image returned by `image.resize`.
    """
    if resample is None:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # under its current name. Fall back for releases that predate it.
        resample = getattr(Image, 'LANCZOS', None)
        if resample is None:
            resample = Image.ANTIALIAS

    width, height = image.size
    side = min(width, height)
    # Floor division keeps the box integral on Python 3, and deriving the far
    # edges from `side` keeps the crop exactly square even when the
    # difference between width and height is odd.
    left = (width - side) // 2
    top = (height - side) // 2
    image = image.crop((left, top, left + side, top + side))
    return image.resize((size, size), resample)
def main(args):
    """Build and pickle the caption vocabulary, then resize the images.

    Args:
        args: parsed command-line namespace providing `caption_path`,
            `threshold` and `vocab_path`.
    """
    vocab_path = args.vocab_path
    # Make sure the output directory exists before the vocabulary is built,
    # so the dump cannot fail on a missing folder afterwards.
    vocab_dir = os.path.dirname(vocab_path)
    if vocab_dir and not os.path.exists(vocab_dir):
        os.makedirs(vocab_dir)

    vocab = build_vocab(json=args.caption_path,
                        threshold=args.threshold)
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)
    print("Total vocabulary size: %d" %len(vocab))
    print("Saved the vocabulary wrapper to '%s'" %vocab_path)

    print("Start resize_image")
    for split in ['train', 'val']:
        folder = './image/%s2014' %split
        resized_folder = './image/%s2014_resized/' %split
        if not os.path.exists(resized_folder):
            os.makedirs(resized_folder)
        # Single-argument print(...) calls behave the same on Python 2 and 3.
        print('Start resizing %s images.' %split)
        image_files = os.listdir(folder)
        num_images = len(image_files)
        for i, image_file in enumerate(image_files):
            # Read-only binary mode suffices; the source file is never written.
            with open(os.path.join(folder, image_file), 'rb') as f:
                with Image.open(f) as image:
                    resized = resize_image(image)
                    # NOTE(review): the format is read from the resized image,
                    # exactly as the original code did - confirm whether the
                    # source image's format was intended instead.
                    resized.save(os.path.join(resized_folder, image_file),
                                 resized.format)
            if i % 100 == 0:
                print('Resized images: %d/%d' %(i, num_images))
if __name__ == '__main__':
    # Every option has a default, so the script can be run without arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--caption_path',
        type=str,
        default='/usr/share/mscoco/annotations/captions_train2014.json',
        help='path for train annotation file')
    cli.add_argument(
        '--vocab_path',
        type=str,
        default='./data/vocab.pkl',
        help='path for saving vocabulary wrapper')
    cli.add_argument(
        '--threshold',
        type=int,
        default=4,
        help='minimum word count threshold')
    main(cli.parse_args())