prepro_vocab.py
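"""Build the caption vocabulary for MSR-VTT-style annotations.

Reads the raw annotation JSON (--input_json), counts caption words, maps
words occurring at or below --word_count_threshold to <UNK>, wraps every
caption in <sos>/<eos>, and writes two files: --info_json (the ix_to_word /
word_to_ix tables plus the train/val/test video-id splits) and
--caption_json (per-video tokenized captions under 'final_captions').
"""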
import re
import json
import argparse
import numpy as np


def build_vocab(vids, params):
    count_thr = params['word_count_threshold']
    # count up the number of words
    counts = {}
    for vid, caps in vids.items():
        for cap in caps['captions']:
            ws = re.sub(r'[.!,;?]', ' ', cap).split()
            for w in ws:
                counts[w] = counts.get(w, 0) + 1
    # cw = sorted([(count, w) for w, count in counts.items()], reverse=True)
    total_words = sum(counts.values())
    bad_words = [w for w, n in counts.items() if n <= count_thr]
    vocab = [w for w, n in counts.items() if n > count_thr]
    bad_count = sum(counts[w] for w in bad_words)
    print('number of bad words: %d/%d = %.2f%%' %
          (len(bad_words), len(counts), len(bad_words) * 100.0 / len(counts)))
    print('number of words in vocab would be %d' % (len(vocab), ))
    print('number of UNKs: %d/%d = %.2f%%' %
          (bad_count, total_words, bad_count * 100.0 / total_words))
    # let's now produce the final annotations
    if bad_count > 0:
        # additional special UNK token we will use below to map infrequent words to
        print('inserting the special UNK token')
        vocab.append('<UNK>')
    for vid, caps in vids.items():
        caps = caps['captions']
        vids[vid]['final_captions'] = []
        for cap in caps:
            ws = re.sub(r'[.!,;?]', ' ', cap).split()
            caption = ['<sos>'] + \
                [w if counts.get(w, 0) > count_thr else '<UNK>' for w in ws] + ['<eos>']
            vids[vid]['final_captions'].append(caption)
    return vocab
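
# A minimal usage sketch (not part of the original script): with
# word_count_threshold=1, words seen only once are mapped to <UNK> and every
# caption is wrapped in <sos>/<eos>. The video ids and captions below are
# invented for illustration.
def _example_build_vocab():
    toy = {
        'video0': {'captions': ['a man is singing', 'a man sings a song']},
        'video1': {'captions': ['a woman is singing']},
    }
    vocab = build_vocab(toy, {'word_count_threshold': 1})
    # vocab keeps only words seen more than once, plus '<UNK>':
    #   ['a', 'man', 'is', 'singing', '<UNK>']  (order may vary)
    # toy['video1']['final_captions'] ==
    #   [['<sos>', 'a', '<UNK>', 'is', 'singing', '<eos>']]
    return vocab, toy
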
def main(params):
    videos = json.load(open(params['input_json'], 'r'))['sentences']
    video_caption = {}
    for i in videos:
        if i['video_id'] not in video_caption.keys():
            video_caption[i['video_id']] = {'captions': []}
        video_caption[i['video_id']]['captions'].append(i['caption'])
    # create the vocab
    vocab = build_vocab(video_caption, params)
    itow = {i + 2: w for i, w in enumerate(vocab)}
    wtoi = {w: i + 2 for i, w in enumerate(vocab)}  # inverse table
    wtoi['<eos>'] = 0
    itow[0] = '<eos>'
    wtoi['<sos>'] = 1
    itow[1] = '<sos>'
    out = {}
    out['ix_to_word'] = itow
    out['word_to_ix'] = wtoi
    out['videos'] = {'train': [], 'val': [], 'test': []}
    videos = json.load(open(params['input_json'], 'r'))['videos']
    for i in videos:
        out['videos'][i['split']].append(int(i['id']))
    json.dump(out, open(params['info_json'], 'w'))
    json.dump(video_caption, open(params['caption_json'], 'w'))
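
# A rough sketch of the two files main() writes (field names follow the code
# above; the concrete ids and words are invented for illustration):
#
# info_json:
#   {"ix_to_word": {"0": "<eos>", "1": "<sos>", "2": "a", ...},
#    "word_to_ix": {"<eos>": 0, "<sos>": 1, "a": 2, ...},
#    "videos": {"train": [0, 1, ...], "val": [...], "test": [...]}}
#
# caption_json:
#   {"video0": {"captions": ["a man is singing", ...],
#               "final_captions": [["<sos>", "a", "man", ...], ...]}}
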
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # input json
    parser.add_argument('--input_json', type=str, default='data/videodatainfo_2017.json',
                        help='msr_vtt videoinfo json')
    parser.add_argument('--info_json', default='data/info.json',
                        help='info about ix2word and word2ix')
    parser.add_argument('--caption_json', default='data/caption.json', help='caption json file')
    parser.add_argument('--word_count_threshold', default=1, type=int,
                        help='only words that occur more than this number of times will be put in vocab')

    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict
    main(params)
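
# Example invocation (paths shown are the argparse defaults above; adjust to
# your data layout):
#   python prepro_vocab.py --input_json data/videodatainfo_2017.json \
#       --info_json data/info.json --caption_json data/caption.json \
#       --word_count_threshold 1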