-
Notifications
You must be signed in to change notification settings - Fork 0
/
vocab.py
219 lines (183 loc) · 7.97 KB
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 4
vocab.py: Vocabulary Generation
Pencheng Yin <pcyin@cs.cmu.edu>
Sahil Chopra <schopra8@stanford.edu>
Usage:
vocab.py --train-src=<file> --train-tgt=<file> [options] VOCAB_FILE
Options:
-h --help Show this screen.
--train-src=<file> File of training source sentences
--train-tgt=<file> File of training target sentences
--size=<int> vocab size [default: 50000]
--freq-cutoff=<int> frequency cutoff [default: 2]
"""
from collections import Counter
from docopt import docopt
from itertools import chain
import json
import torch
from typing import List
from utils import read_corpus, pad_sents
class VocabEntry(object):
    """ Vocabulary Entry, i.e. structure containing either
    src or tgt language terms.

    Two mappings are maintained: `word2id` (word -> index) and
    `id2word` (index -> word). `add` updates both of them.
    """

    def __init__(self, word2id=None):
        """ Init VocabEntry Instance.
        @param word2id (dict): dictionary mapping words 2 indices. If it is
            omitted (or empty), a fresh mapping holding only the four special
            tokens is created. A supplied mapping must contain '<unk>'.
        """
        if word2id:
            self.word2id = word2id
        else:
            self.word2id = dict()
            self.word2id['<pad>'] = 0   # Pad Token
            self.word2id['<s>'] = 1     # Start Token
            self.word2id['</s>'] = 2    # End Token
            self.word2id['<unk>'] = 3   # Unknown Token
        self.unk_id = self.word2id['<unk>']
        # NOTE: this instance attribute shadows the `id2word` method defined
        # further down, so `entry.id2word` is always this dict.
        self.id2word = {v: k for k, v in self.word2id.items()}

    def __getitem__(self, word):
        """ Retrieve word's index. Return the index for the unk
        token if the word is out of vocabulary.
        @param word (str): word to look up.
        @returns index (int): index of word
        """
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        """ Check if word is captured by VocabEntry.
        @param word (str): word to look up
        @returns contains (bool): whether word is contained
        """
        return word in self.word2id

    def __setitem__(self, key, value):
        """ Raise error, if one tries to edit the VocabEntry.
        Item assignment is always rejected; use `add` to insert new words.
        """
        raise ValueError('vocabulary is readonly')

    def __len__(self):
        """ Compute number of words in VocabEntry.
        @returns len (int): number of words in VocabEntry
        """
        return len(self.word2id)

    def __repr__(self):
        """ Representation of VocabEntry to be used
        when printing the object.
        """
        return 'Vocabulary[size=%d]' % len(self)

    def id2word(self, wid):
        """ Return mapping of index to word.
        NOTE: on instances this method is hidden by the `id2word` dict that
        __init__ assigns, so index lookups are written `entry.id2word[wid]`.
        The definition is retained so the class-level name keeps existing.
        @param wid (int): word index
        @returns word (str): word corresponding to index
        """
        return self.id2word[wid]

    def add(self, word):
        """ Add word to VocabEntry, if it is previously unseen.
        @param word (str): word to add to VocabEntry
        @return index (int): index that the word has been assigned
        """
        if word not in self:
            # The new index is the current vocabulary size.
            wid = self.word2id[word] = len(self)
            self.id2word[wid] = word
            return wid
        else:
            return self[word]

    def words2indices(self, sents):
        """ Convert list of words or list of sentences of words
        into list or list of list of indices.
        An empty input yields an empty list.
        @param sents (list[str] or list[list[str]]): sentence(s) in words
        @return word_ids (list[int] or list[list[int]]): sentence(s) in indices
        """
        # Guard on emptiness before peeking at the first element; the first
        # element decides whether this is a batch of sentences or one sentence.
        if sents and isinstance(sents[0], list):
            return [[self[w] for w in s] for s in sents]
        else:
            return [self[w] for w in sents]

    def indices2words(self, word_ids):
        """ Convert list of indices into words.
        @param word_ids (list[int]): list of word ids
        @return sents (list[str]): list of words
        """
        return [self.id2word[w_id] for w_id in word_ids]

    def to_input_tensor(self, sents: List[List[str]], device: torch.device) -> torch.Tensor:
        """ Convert list of sentences (words) into tensor with necessary padding for
        shorter sentences.
        @param sents (List[List[str]]): list of sentences (words)
        @param device: device on which to load the tensor, i.e. CPU or GPU
        @returns sents_var: tensor of (max_sentence_length, batch_size)
        """
        word_ids = self.words2indices(sents)
        sents_t = pad_sents(word_ids, self['<pad>'])  # num_sent x max_lens
        sents_var = torch.tensor(sents_t, dtype=torch.long, device=device)
        return torch.t(sents_var)  # transpose, shape

    @staticmethod
    def from_corpus(corpus, size, freq_cutoff=5):
        """ Given a corpus construct a Vocab Entry.
        @param corpus (list[str]): corpus of text produced by read_corpus function
        @param size (int): # of words in vocabulary
        @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word
        @returns vocab_entry (VocabEntry): VocabEntry instance produced from provided corpus
        """
        vocab_entry = VocabEntry()
        # from_iterable walks the corpus lazily instead of unpacking every
        # sentence into a positional argument.
        word_freq = Counter(chain.from_iterable(corpus))
        valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
        print('number of word types: {}, number of word types w/ frequency >= {}: {}'
              .format(len(word_freq), freq_cutoff, len(valid_words)))
        # Most frequent words first; keep at most `size` of them.
        top_k_words = sorted(valid_words, key=lambda w: word_freq[w], reverse=True)[:size]
        for word in top_k_words:
            vocab_entry.add(word)
        return vocab_entry
class Vocab(object):
    """ Vocab encapsulating src and target languages.
    """

    def __init__(self, src_vocab: VocabEntry, tgt_vocab: VocabEntry):
        """ Init Vocab.
        @param src_vocab (VocabEntry): VocabEntry for source language
        @param tgt_vocab (VocabEntry): VocabEntry for target language
        """
        self.src = src_vocab
        self.tgt = tgt_vocab

    @staticmethod
    def build(src_sents, tgt_sents, vocab_size, freq_cutoff) -> 'Vocab':
        """ Build Vocabulary.
        @param src_sents (list[str]): Source sentences provided by read_corpus() function
        @param tgt_sents (list[str]): Target sentences provided by read_corpus() function
        @param vocab_size (int): Size of vocabulary for both source and target languages
        @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word.
        @returns vocab (Vocab): Vocab holding the source and target VocabEntry
        """
        # The two corpora must have the same number of sentences.
        assert len(src_sents) == len(tgt_sents)

        print('initialize source vocabulary ..')
        src = VocabEntry.from_corpus(src_sents, vocab_size, freq_cutoff)

        print('initialize target vocabulary ..')
        tgt = VocabEntry.from_corpus(tgt_sents, vocab_size, freq_cutoff)

        return Vocab(src, tgt)

    def save(self, file_path):
        """ Save Vocab to file as JSON dump.
        The file holds one object with the keys 'src_word2id' and 'tgt_word2id'.
        @param file_path (str): file path to vocab file
        """
        payload = dict(src_word2id=self.src.word2id, tgt_word2id=self.tgt.word2id)
        # The context manager guarantees the handle is flushed and closed,
        # even if serialization raises part-way through.
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)

    @staticmethod
    def load(file_path):
        """ Load vocabulary from JSON dump.
        @param file_path (str): file path to vocab file
        @returns Vocab object loaded from JSON dump
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            entry = json.load(f)
        src_word2id = entry['src_word2id']
        tgt_word2id = entry['tgt_word2id']

        return Vocab(VocabEntry(src_word2id), VocabEntry(tgt_word2id))

    def __repr__(self):
        """ Representation of Vocab to be used
        when printing the object.
        """
        return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))
if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring.
    args = docopt(__doc__)

    src_path = args['--train-src']
    tgt_path = args['--train-tgt']
    print('read in source sentences: %s' % src_path)
    print('read in target sentences: %s' % tgt_path)

    # Load both sides of the training corpus.
    src_sents = read_corpus(src_path, source='src')
    tgt_sents = read_corpus(tgt_path, source='tgt')

    # docopt hands option values over as strings; convert before building.
    vocab_size = int(args['--size'])
    freq_cutoff = int(args['--freq-cutoff'])
    vocab = Vocab.build(src_sents, tgt_sents, vocab_size, freq_cutoff)
    print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))

    # Write the vocabulary to the requested output file.
    out_path = args['VOCAB_FILE']
    vocab.save(out_path)
    print('vocabulary saved to %s' % out_path)