-
Notifications
You must be signed in to change notification settings - Fork 561
/
torchinputconns.cc
234 lines (192 loc) · 6.65 KB
/
torchinputconns.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#include "torchinputconns.h"
namespace dd {
using namespace torch;
// ===== TorchDataset
/**
 * Appends one example (its input tensors and its target tensors) to the
 * dataset as a new TorchBatch entry.
 *
 * @param data   input tensors of the example (taken by value and moved in)
 * @param target target tensors of the example; may be empty (taken by value
 *               and moved in)
 */
void TorchDataset::add_batch(std::vector<at::Tensor> data, std::vector<at::Tensor> target)
{
  // The parameters are already private copies, so move them into the stored
  // batch instead of copying every tensor handle (and vector buffer) again.
  _batches.push_back(TorchBatch(std::move(data), std::move(target)));
}
/**
 * Rebuilds the list of pending batch indices so that every stored batch is
 * served again.
 *
 * The indices 0..N-1 are pushed in order. When shuffling is enabled they are
 * permuted with a std::mt19937 seeded from _seed, or from the current time
 * when _seed is -1.
 */
void TorchDataset::reset()
{
  _indices.clear();

  const int64_t batch_count = _batches.size();
  for (int64_t idx = 0; idx < batch_count; ++idx)
    _indices.push_back(idx);

  if (!_shuffle)
    return;

  // -1 is the "no fixed seed" sentinel: fall back to wall-clock time.
  const auto seed = (_seed != -1) ? _seed : static_cast<long>(time(NULL));
  std::mt19937 engine(seed);
  std::shuffle(_indices.begin(), _indices.end(), engine);
}
/**
 * Pops up to request[0] pending examples (from the back of _indices) and
 * stacks them into a single batch.
 *
 * The k-th data tensor of every selected example is stacked into the k-th
 * output data tensor, and likewise for the targets.
 *
 * @param request batch request; only its first element (the number of
 *                examples wanted) is used
 * @return the stacked batch, or nullopt when the request is empty, asks for
 *         zero examples, or no pending indices remain
 */
c10::optional<TorchBatch> TorchDataset::get_batch(BatchRequestType request)
{
  // Guard against reading request[0] out of bounds.
  if (request.empty())
    return torch::nullopt;

  size_t count = request[0];
  count = count < _indices.size() ? count : _indices.size();
  if (count == 0)
    return torch::nullopt;

  // data[k] / target[k] collect the k-th tensor of every selected example.
  std::vector<std::vector<Tensor>> data, target;
  while (count != 0)
  {
    const auto id = _indices.back();
    // Bind by reference: copying the whole TorchBatch per example is wasted work.
    const auto &entry = _batches[id];

    if (data.size() < entry.data.size())
      data.resize(entry.data.size());
    for (size_t i = 0; i < entry.data.size(); ++i)
      data[i].push_back(entry.data[i]);

    if (target.size() < entry.target.size())
      target.resize(entry.target.size());
    for (size_t i = 0; i < entry.target.size(); ++i)
      target[i].push_back(entry.target[i]);

    _indices.pop_back();
    --count;
  }

  std::vector<Tensor> data_tensors;
  data_tensors.reserve(data.size());
  for (const auto &vec : data)
    data_tensors.push_back(torch::stack(vec));

  std::vector<Tensor> target_tensors;
  target_tensors.reserve(target.size());
  for (const auto &vec : target)
    target_tensors.push_back(torch::stack(vec));

  return TorchBatch{ std::move(data_tensors), std::move(target_tensors) };
}
/**
 * Resets the index list (reshuffling if enabled) and returns a single batch
 * of cache_size() examples.
 *
 * @throws InputConnectorInternalException if no batch could be produced
 */
TorchBatch TorchDataset::get_cached()
{
  reset();
  c10::optional<TorchBatch> all_examples = get_batch({cache_size()});
  if (all_examples.has_value())
    return all_examples.value();
  throw InputConnectorInternalException("No data provided");
}
/**
 * Returns a new dataset holding the batches in the fractional range
 * [start, stop) of this dataset, in storage order.
 *
 * Both bounds are mapped to indices with the same rule, floor(size * ratio).
 * This means split(0, r) and split(r, 1) partition the data with no overlap
 * and no gap. Ratios are clamped to [0, 1], and stop < start yields an empty
 * dataset.
 *
 * NOTE(review): new_dataset is default-constructed, so this dataset's
 * shuffle/seed settings are not carried over — confirm this is intended.
 *
 * @param start fraction of the data where the split begins
 * @param stop  fraction of the data where the split ends
 */
TorchDataset TorchDataset::split(double start, double stop)
{
  const size_t datasize = _batches.size();

  // Previously the stop bound was end - floor(size * (1 - stop)), which rounds
  // up instead of down. When size * stop was fractional, split(0, r) and
  // split(r, 1) then shared one batch.
  auto to_index = [datasize](double ratio) -> size_t {
    if (!(ratio > 0.0)) // also maps NaN to 0
      return 0;
    if (ratio >= 1.0)
      return datasize;
    return static_cast<size_t>(datasize * ratio);
  };

  const size_t start_idx = to_index(start);
  size_t stop_idx = to_index(stop);
  if (stop_idx < start_idx)
    stop_idx = start_idx;

  TorchDataset new_dataset;
  new_dataset._batches.insert(new_dataset._batches.end(),
                              _batches.begin() + static_cast<int64_t>(start_idx),
                              _batches.begin() + static_cast<int64_t>(stop_idx));
  return new_dataset;
}
// ===== TxtTorchInputFileConn
/**
 * Reads torch-specific input parameters.
 *
 * Currently only "width" is read; it is stored in _width as an int. Other
 * keys are ignored.
 *
 * @param ad_input the "parameters.input" object of the request
 */
void TxtTorchInputFileConn::fillup_parameters(const APIData &ad_input)
{
  if (!ad_input.has("width"))
    return;
  _width = ad_input.get("width").get<int>();
}
/**
 * Parses the text input and converts it into torch datasets.
 *
 * Steps:
 *  - runs the generic text connector;
 *  - checks that the input is tokenized as ordered words;
 *  - applies the torch input parameters;
 *  - caches the special-token ids;
 *  - fills _dataset from _txt, and _test_dataset from _test_txt when that is
 *    non-empty.
 *
 * @throws InputConnectorBadParamException unless ordered_words is set and
 *         characters mode is off
 * @throws std::out_of_range (from unordered_map::at) if [CLS], [SEP], [UNK]
 *         or [MASK] is missing from _vocab
 */
void TxtTorchInputFileConn::transform(const APIData &ad)
{
  // Generating a vocabulary from scratch is not currently supported with the
  // torch backend, so generation is disabled unconditionally.
  // NOTE(review): a commented-out `if (_finetuning)` guard suggests this was
  // once meant to apply only when finetuning.
  _generate_vocab = false;

  // Any exception from the base parser propagates unchanged. (The former
  // catch-and-rethrow block around this call was a no-op.)
  TxtInputFileConn::transform(ad);

  if (!_ordered_words || _characters)
    throw InputConnectorBadParamException("Need ordered_words = true with backend torch");

  if (ad.has("parameters") && ad.getobj("parameters").has("input"))
  {
    APIData ad_input = ad.getobj("parameters").getobj("input");
    fillup_parameters(ad_input);
  }

  // Cache the vocabulary positions of the special tokens used when building
  // sequences and masked-LM batches.
  _cls_pos = _vocab.at("[CLS]")._pos;
  _sep_pos = _vocab.at("[SEP]")._pos;
  _unk_pos = _vocab.at("[UNK]")._pos;
  _mask_id = _vocab.at("[MASK]")._pos;

  fill_dataset(_dataset, _txt);
  if (!_test_txt.empty())
    fill_dataset(_test_dataset, _test_txt);
}
// Builds a masked-language-model training batch from `example`.
//
// Inputs:
//  - example.data[0]: token ids. It is cloned, so `example` is not modified.
//  - example.data[2]: attention mask.
//  Both are read through accessor<int64_t,2>, so they must be 2-D int64
//  tensors. Presumably each row is one sequence, laid out as in fill_dataset:
//  {input ids, token type ids, attention mask} — TODO confirm for all callers.
//
// Returns a TorchBatch whose:
//  - target is {lm_labels}: -1 everywhere except at selected positions, which
//    hold the original token id. NOTE(review): -1 is presumably an "ignore"
//    label for the loss — confirm.
//  - data is {masked input ids, example.data[1], example.data[2], ...}.
// The original example.target is not carried over.
//
// The RNG draw order below matters for reproducibility with a seeded _rng:
// the first uniform draw happens for every attended position (even [SEP]).
// The second draw and the vocab draw only happen for selected positions.
// NOTE(review): vocab_size() == 0 would give an invalid int distribution (0, -1).
TorchBatch TxtTorchInputFileConn::generate_masked_lm_batch(const TorchBatch &example)
{
std::uniform_real_distribution<double> uniform(0, 1);
// Replacement ids for the "random token" case, drawn uniformly from [0, vocab_size() - 1].
std::uniform_int_distribution<int64_t> vocab_distrib(0, vocab_size() - 1);
Tensor input_ids = example.data.at(0).clone();
Tensor lm_labels = torch::ones_like(input_ids, TensorOptions(kLong)) * -1;
// mask random tokens
auto input_acc = input_ids.accessor<int64_t,2>();
auto att_mask_acc = example.data.at(2).accessor<int64_t,2>();
auto labels_acc = lm_labels.accessor<int64_t,2>();
for (int i = 0; i < input_ids.size(0); ++i)
{
int j = 1; // skip [CLS] token
// Walk the row until the first position whose attention mask is 0, or the row ends.
while (j < input_ids.size(1) && att_mask_acc[i][j] != 0)
{
double rand_num = uniform(_rng);
// Select this position with probability _change_prob; [SEP] tokens are never selected.
if (rand_num < _lm_params._change_prob && input_acc[i][j] != _sep_pos)
{
labels_acc[i][j] = input_acc[i][j];
rand_num = uniform(_rng);
// Selected: replace with [MASK] with probability _mask_prob ...
if (rand_num < _lm_params._mask_prob)
{
input_acc[i][j] = mask_id();
}
// ... with a random vocab id with probability _rand_prob; otherwise keep the original token.
else if (rand_num < _lm_params._mask_prob + _lm_params._rand_prob)
{
input_acc[i][j] = vocab_distrib(_rng);
}
}
++j;
}
}
TorchBatch output;
output.target.push_back(lm_labels);
output.data.push_back(input_ids);
// Forward the remaining input tensors (index 1 onwards) unchanged.
for (int i = 1; i < example.data.size(); ++i)
{
output.data.push_back(example.data[i]);
}
return output;
}
/**
 * Converts tokenized text entries into fixed-width examples and adds them to
 * `dataset`.
 *
 * Each entry becomes three int64 tensors of length _width:
 *  - input ids: [CLS], word ids (unknown words become [UNK]), [SEP], right-padded with 0;
 *  - token type ids: all 0;
 *  - attention mask: 1 over real tokens, 0 over padding.
 * Words beyond _width - 2 are dropped. When the entry's target is not -1, a
 * single-element long tensor holding it is added as the example's target.
 *
 * @param dataset destination dataset
 * @param entries text entries; they are downcast to TxtOrderedWordsEntry
 *                (transform() checks _ordered_words before calling this)
 * @throws InputConnectorBadParamException if _width < 2, since [CLS] and
 *         [SEP] alone need two slots
 */
void TxtTorchInputFileConn::fill_dataset(TorchDataset &dataset,
                                         const std::vector<TxtEntry<double>*> &entries)
{
  // Previously `ids.size() >= _width - 1` compared size_t with int.
  // - For _width <= 0 the limit wrapped to a huge value, so truncation never triggered.
  // - The padding below then went negative, which makes constant_pad_nd crop the
  //   tensors (dropping [SEP]).
  // Reject widths that cannot hold [CLS] + [SEP].
  if (_width < 2)
    throw InputConnectorBadParamException("width must be at least 2 with backend torch");
  // Maximum number of ids ([CLS] + words) before [SEP] is appended.
  const size_t max_ids_before_sep = static_cast<size_t>(_width) - 1;

  for (auto *te : entries)
  {
    TxtOrderedWordsEntry *tow = static_cast<TxtOrderedWordsEntry *>(te);
    tow->reset();

    std::vector<int64_t> ids;
    ids.reserve(max_ids_before_sep + 1);
    ids.push_back(_cls_pos);
    while (tow->has_elt() && ids.size() < max_ids_before_sep)
    {
      std::string word;
      double val = 0.0; // required out-param of get_next_elt; unused here
      tow->get_next_elt(word, val);
      auto it = _vocab.find(word);
      if (it != _vocab.end())
        ids.push_back(it->second._pos);
      else
        ids.push_back(_unk_pos);
    }
    ids.push_back(_sep_pos);

    at::Tensor ids_tensor = toLongTensor(ids);
    at::Tensor mask_tensor = torch::ones_like(ids_tensor);
    at::Tensor token_type_ids_tensor = torch::zeros_like(ids_tensor);

    // Right-pad everything with 0 up to exactly _width. The padding is
    // non-negative because ids.size() <= _width.
    const int64_t padding_size = _width - ids_tensor.sizes().back();
    ids_tensor = torch::constant_pad_nd(
        ids_tensor, at::IntList{0, padding_size}, 0);
    mask_tensor = torch::constant_pad_nd(
        mask_tensor, at::IntList{0, padding_size}, 0);
    token_type_ids_tensor = torch::constant_pad_nd(
        token_type_ids_tensor, at::IntList{0, padding_size}, 0);

    // A target of -1 adds no target tensor. NOTE(review): presumably this
    // marks unlabeled data — confirm.
    std::vector<Tensor> target_vec;
    const int target_val = static_cast<int>(tow->_target);
    if (target_val != -1)
    {
      Tensor target_tensor = torch::full(1, target_val, torch::kLong);
      target_vec.push_back(target_tensor);
    }

    dataset.add_batch({ids_tensor, token_type_ids_tensor, mask_tensor}, std::move(target_vec));
  }
}
}