tokenizer.lua
--
-- Copyright (c) 2015, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Author: Sumit Chopra <spchopra@fb.com>
--         Michael Mathieu <myrhev@fb.com>
--         Marc'Aurelio Ranzato <ranzato@fb.com>
--         Tomas Mikolov <tmikolov@fb.com>
--         Armand Joulin <ajoulin@fb.com>

-- Build the word-based dataset for a text corpus. Words are clustered
-- based on their frequency to generate buckets with equal probability.
require('math')
require('torch')   -- provides torch.Tensor, torch.sort, torch.save, etc.
require('paths')   -- provides paths.concat used when saving the dictionary
local ffivector = require('fb.ffivector')
local pl = require('pl.import_into')()

local Tokenizer = {}
function Tokenizer.build_dictionary(config, trainfname)
    local kMaxDictSize = 500000
    local dict = {}
    dict.symbol_to_index = {}   -- string -> id
    dict.index_to_symbol = {}   -- id -> string
    dict.index_to_freq = torch.Tensor(kMaxDictSize) -- id -> freq
    dict.index_to_cluster = nil -- id -> cluster_id
    dict.index_to_index_within_cluster = nil
    dict.cluster_to_index = {}  -- reverse mapping from cluster to word id
    dict.mapping = nil -- cluster_id to within_cluster_id mapping used by hsm
    local nr_clusters = config.nclusters
    local threshold = config.threshold
    print("[loading: " .. trainfname .. "]")
    local nr_words = 1     -- number of unique words
    local tot_nr_words = 0 -- total number of words in corpus
    -- Add by default an UNK token to be used for the rare entries
    local unk = "<UNK>"
    dict.symbol_to_index[unk] = nr_words
    dict.index_to_symbol[nr_words] = unk
    dict.index_to_freq[nr_words] = 1
    local cnt = 0
    for s in io.lines(trainfname) do
        -- remove all the tabs in the string
        s = s:gsub("\t", "")
        -- convert multiple spaces into a single space: this is needed to
        -- make the following pl.utils.split() function return only words
        -- and not white spaces
        s = s:gsub("%s+", " ")
        local words = pl.utils.split(s, ' ')
        for i, word in pairs(words) do
            if word ~= "" then -- somehow the first token is always ""
                if dict.symbol_to_index[word] == nil then
                    nr_words = nr_words + 1
                    dict.symbol_to_index[word] = nr_words
                    dict.index_to_symbol[nr_words] = word
                    dict.index_to_freq[nr_words] = 1
                else
                    local indx = dict.symbol_to_index[word]
                    dict.index_to_freq[indx] = dict.index_to_freq[indx] + 1
                end
                cnt = cnt + 1
            end
        end
        -- Add an end-of-sentence token </s> after every line
        if dict.symbol_to_index["</s>"] == nil then
            nr_words = nr_words + 1
            dict.symbol_to_index["</s>"] = nr_words
            dict.index_to_symbol[nr_words] = "</s>"
            dict.index_to_freq[nr_words] = 1
        else
            local indx = dict.symbol_to_index["</s>"]
            dict.index_to_freq[indx] = dict.index_to_freq[indx] + 1
        end
        cnt = cnt + 1
    end
    dict.index_to_freq:resize(nr_words)
    tot_nr_words = dict.index_to_freq:sum()
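    -- The "- 1" in the counts below discounts the <UNK> entry that was
    -- pre-added above (one unique word with an initial frequency of 1).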
print("[Done making the dictionary. There are " .. nr_words - 1 ..
" unique words and a total of " .. tot_nr_words - 1 ..
" words in the training set.]")
-- map rare words to special token and skip corresponding indices
-- if the specified threshold is greater than 0
local removed = 0
local net_nwords = 1
if threshold > 0 then
for i = 2, dict.index_to_freq:size(1) do
local word = dict.index_to_symbol[i]
if dict.index_to_freq[i] < threshold then
dict.index_to_freq[1] =
dict.index_to_freq[1] + dict.index_to_freq[i]
dict.index_to_freq[i] = 0
dict.symbol_to_index[word] = 1
removed = removed + 1
else
-- re-adjust the indices to make them continuous
net_nwords = net_nwords + 1
dict.index_to_freq[net_nwords] = dict.index_to_freq[i]
dict.symbol_to_index[word] = net_nwords
dict.index_to_symbol[net_nwords] = word
end
end
print('[Removed ' .. removed .. ' rare words. ' ..
'Effective number of words ' .. net_nwords .. ']')
dict.index_to_freq:resize(net_nwords)
else
net_nwords = nr_words
end
-- create the cluster index tensors
dict.index_to_cluster = torch.LongTensor(net_nwords):fill(0)
dict.index_to_index_within_cluster = torch.LongTensor(net_nwords):fill(0)
-- sort the tokens by frequency
local sorted_freqs, sorted_indx = torch.sort(dict.index_to_freq, true)
sorted_freqs:div(math.max(1, tot_nr_words))
if nr_clusters == 0 then
nr_clusters = math.floor(math.sqrt(net_nwords))
end
local probab_mass = 1.0 / nr_clusters
local current_mass = 0
local cluster_id = 1
local within_cluster_index = 0
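    -- Greedily assign words (in decreasing frequency order) to clusters so
    -- that each cluster covers roughly 1/nr_clusters of the total probability
    -- mass; this yields the roughly equal-probability buckets used by the
    -- hierarchical softmax. Illustrative example: with unigram probabilities
    -- {0.4, 0.3, 0.2, 0.1} and nr_clusters = 2 (mass per cluster = 0.5), the
    -- first cluster takes the words with mass 0.4 and 0.3, the second cluster
    -- takes the remaining two.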
    for w = 1, net_nwords do
        if current_mass < probab_mass then
            current_mass = current_mass + sorted_freqs[w]
            within_cluster_index = within_cluster_index + 1
        else
            cluster_id = cluster_id + 1
            current_mass = sorted_freqs[w]
            within_cluster_index = 1
        end
        dict.index_to_cluster[sorted_indx[w]] = cluster_id
        dict.index_to_index_within_cluster[sorted_indx[w]] =
            within_cluster_index
    end
    print("[Created " .. cluster_id .. " clusters.]")
    -- Count how many words per cluster there are.
    local wordsPerCluster = torch.zeros(cluster_id)
    for w = 1, net_nwords do
        local curr_cluster = dict.index_to_cluster[w]
        wordsPerCluster[curr_cluster] = wordsPerCluster[curr_cluster] + 1
    end
    -- build the reverse index from cluster id back to word id,
    -- and the explicit (cluster, within-cluster) mapping used by hsm
    dict.mapping = torch.LongTensor(net_nwords, 2)
    for c = 1, cluster_id do
        table.insert(dict.cluster_to_index,
                     torch.LongTensor(wordsPerCluster[c]))
    end
    for w = 1, net_nwords do
        local curr_cluster = dict.index_to_cluster[w]
        local curr_word = dict.index_to_index_within_cluster[w]
        dict.cluster_to_index[curr_cluster][curr_word] = w
        dict.mapping[w][1] = curr_cluster
        dict.mapping[w][2] = curr_word
    end
    dict.separatorIndex = dict.symbol_to_index['</s>']
    -- Save the dictionary to disk.
    local dictfname = paths.concat(config.dest_path,
                                   config.name .. '.dictionary' ..
                                       '_nclust=' .. nr_clusters ..
                                       '_thresh=' .. config.threshold ..
                                       '.th7')
    torch.save(dictfname, dict)
    print('There are effectively ' .. net_nwords .. ' words in the corpus.')
    -- return the dictionary and the number of clusters
    return dict, nr_clusters
end
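
-- Note: build_dictionary both saves the dictionary to
-- <dest_path>/<name>.dictionary_nclust=<nclusters>_thresh=<threshold>.th7
-- and returns it (together with the number of clusters) for direct use.
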
-- This function tokenizes the data (converts words to word ids)
-- and stores the result in a 1D tensor.
-- Inputs:
--   dict: dictionary
--   filenameIn: full path of the input file
--   filenameOut: full path of the output file
--   config: configuration parameters of the data
--   shuff: whether to shuffle the sentences or not (true|false)
function Tokenizer.tokenize(dict, filenameIn, filenameOut, config, shuff)
    print("saving to " .. filenameOut)
    local unk = "<UNK>"
    local threshold = config.threshold
    local eos = config.eos
    -- first count how many words there are in the corpus
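    -- All raw lines are buffered in an fb.ffivector string vector (kept out
    -- of the Lua heap) so that they can be re-emitted below in the permuted
    -- order without re-reading the file.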
    local all_lines = ffivector.new_string()
    local tot_nr_words = 0
    local tot_lines = 0
    for s in io.lines(filenameIn) do
        -- store the line
        tot_lines = tot_lines + 1
        all_lines[tot_lines] = s
        -- remove all the tabs in the string
        s = s:gsub("\t", "")
        -- remove leading and trailing white spaces
        s = s:gsub("^%s+", ""):gsub("%s+$", "")
        -- convert multiple spaces into a single space: this is needed to
        -- make the following pl.utils.split() function return only words
        -- and not white spaces
        s = s:gsub("%s+", " ")
        -- count the words
        local words = pl.utils.split(s, ' ')
        tot_nr_words = tot_nr_words + #words -- nr. words in the line
        tot_nr_words = tot_nr_words + 1      -- newline
    end
    print('-- total lines: ' .. tot_lines)
    -- get the permutation vector
    local perm_vec
    if shuff == true then
        print('-- shuffling the data')
        perm_vec = torch.randperm(tot_lines)
    else
        print('-- not shuffling the data')
        perm_vec = torch.range(1, tot_lines)
    end
    -- now store the word ids in the tensor
    local data = torch.Tensor(tot_nr_words)
    local id = 0
    local cnt = 1
    for ln = 1, tot_lines do
        local s = all_lines[perm_vec[ln]]
        -- remove all the tabs in the string
        s = s:gsub("\t", "")
        -- remove leading and trailing white spaces
        s = s:gsub("^%s+", ""):gsub("%s+$", "")
        -- convert multiple spaces into a single space: this is needed to
        -- make the following pl.utils.split() function return only words
        -- and not white spaces
        s = s:gsub("%s+", " ")
        collectgarbage()
        local words = pl.utils.split(s, ' ')
        for i, word in pairs(words) do
            if word ~= "" then
                if dict.symbol_to_index[word] == nil or
                    dict.index_to_freq[dict.symbol_to_index[word]] < threshold then
                    print('WARNING: ' .. word .. ' being replaced by ' .. unk)
                    id = dict.symbol_to_index[unk]
                else
                    id = dict.symbol_to_index[word]
                end
                data[cnt] = id
                cnt = cnt + 1
            end
        end
        -- Add the end-of-sentence token </s> if specified
        if eos == true then
            id = dict.symbol_to_index["</s>"]
            data[cnt] = id
            cnt = cnt + 1
        end
        collectgarbage()
    end
    torch.save(filenameOut, data)
end
return Tokenizer
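
--[[
Usage sketch (illustrative only): the field names below are the ones this file
actually reads (config.nclusters, config.threshold, config.dest_path,
config.name, config.eos); the concrete values and file paths are assumptions.

    local Tokenizer = require('tokenizer')  -- adjust the module path as needed
    local config = {
        nclusters = 0,          -- 0 => floor(sqrt(vocabulary size)) clusters
        threshold = 5,          -- words seen fewer than 5 times map to <UNK>
        dest_path = './data',
        name      = 'mycorpus',
        eos       = true,       -- append </s> after every line
    }
    local dict = Tokenizer.build_dictionary(config, './data/train.txt')
    -- last argument: whether to shuffle the lines before writing the id tensor
    Tokenizer.tokenize(dict, './data/train.txt', './data/train.th7', config, false)
--]]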