From 121831e63fd43b78a1316ef50b20d6e5dd6bd019 Mon Sep 17 00:00:00 2001 From: Louis J Date: Wed, 24 Jul 2019 18:07:57 +0200 Subject: [PATCH 01/14] Move dataset management and model building in separate classes --- src/backends/torch/torchinputconns.cc | 90 +++++++++++++++++++++++---- src/backends/torch/torchinputconns.h | 46 +++++++++++--- src/backends/torch/torchlib.cc | 89 +++++++++++++++++++------- src/backends/torch/torchlib.h | 22 ++++++- 4 files changed, 204 insertions(+), 43 deletions(-) diff --git a/src/backends/torch/torchinputconns.cc b/src/backends/torch/torchinputconns.cc index b5950f5ac..c1b100ea9 100644 --- a/src/backends/torch/torchinputconns.cc +++ b/src/backends/torch/torchinputconns.cc @@ -2,6 +2,81 @@ namespace dd { +using namespace torch; + + +void TorchDataset::add_batch(std::vector data, std::vector target) +{ + _batches.push_back(TorchBatch(data, target)); +} + +void TorchDataset::reset() +{ + _indices.clear(); + + for (int64_t i = 0; i < _batches.size(); ++i) { + _indices.push_back(i); + } + + if (_shuffle) + { + auto seed = _seed == -1 ? static_cast(time(NULL)) : _seed; + std::shuffle(_indices.begin(), _indices.end(), std::mt19937(seed)); + } +} + +c10::optional TorchDataset::get_batch(BatchRequestType request) +{ + size_t count = request[0]; + count = count < _indices.size() ? count : _indices.size(); + + if (count == 0) { + return torch::nullopt; + } + + std::vector> data, target; + + while(count != 0) { + auto id = _indices.back(); + auto entry = _batches[id]; + + for (int i = 0; i < entry.data.size(); ++i) + { + while (i >= data.size()) + data.emplace_back(); + data[i].push_back(entry.data.at(i)); + } + for (int i = 0; i < entry.target.size(); ++i) + { + while (i >= target.size()) + target.emplace_back(); + target[i].push_back(entry.target.at(i)); + } + + _indices.pop_back(); + count--; + } + + std::vector data_tensors; + for (auto vec : data) + data_tensors.push_back(torch::stack(vec)); + + std::vector target_tensors; + for (auto vec : target) + target_tensors.push_back(torch::stack(vec)); + + return TorchBatch{ data_tensors, target_tensors }; +} + +TorchBatch TorchDataset::get_cached() { + reset(); + auto batch = get_batch({cache_size()}); + if (!batch) + throw InputConnectorInternalException("No data provided"); + return batch.value(); +} + + void TxtTorchInputFileConn::transform(const APIData &ad) { try { @@ -26,9 +101,6 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { int sep_pos = _vocab.at("[SEP]")._pos; int unk_pos = _vocab.at("[UNK]")._pos; - std::vector vids; - std::vector vmask; - for (auto *te : _txt) { TxtOrderedWordsEntry *tow = static_cast(te); @@ -58,22 +130,18 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { at::Tensor ids_tensor = toLongTensor(ids); at::Tensor mask_tensor = torch::ones_like(ids_tensor); - // at::Tensor token_type_ids_tensor = torch::zeros_like(ids_tensor); + at::Tensor token_type_ids_tensor = torch::zeros_like(ids_tensor); int64_t padding_size = _width - ids_tensor.sizes().back(); ids_tensor = torch::constant_pad_nd( ids_tensor, at::IntList{0, padding_size}, 0); mask_tensor = torch::constant_pad_nd( mask_tensor, at::IntList{0, padding_size}, 0); - // token_type_ids_tensor = torch::constant_pad_nd( - // token_type_ids_tensor, at::IntList{0, padding_size}, 0); + token_type_ids_tensor = torch::constant_pad_nd( + token_type_ids_tensor, at::IntList{0, padding_size}, 0); - vids.push_back(ids_tensor); - vmask.push_back(mask_tensor); + _dataset.add_batch({ids_tensor, token_type_ids_tensor, mask_tensor}, {}); } - - _in = torch::stack(vids, 0); - _attention_mask = torch::stack(vmask, 0); } } \ No newline at end of file diff --git a/src/backends/torch/torchinputconns.h b/src/backends/torch/torchinputconns.h index 583978f30..303609c1a 100644 --- a/src/backends/torch/torchinputconns.h +++ b/src/backends/torch/torchinputconns.h @@ -31,12 +31,46 @@ namespace dd { + typedef torch::data::Example, std::vector> TorchBatch; + + class TorchDataset : public torch::data::BatchDataset + > + { + private: + bool _shuffle = false; + long _seed = -1; + std::vector _indices; + + public: + std::vector _batches; + + + TorchDataset() {} + + void add_batch(std::vector data, std::vector target = {}); + + void reset(); + + // Size of data loaded in memory + size_t cache_size() const { return _batches.size(); } + + c10::optional size() const override { + return cache_size(); + } + + c10::optional get_batch(BatchRequestType request) override; + + // Returns a batch containing all the cached data + TorchBatch get_cached(); + }; + + + class TorchInputInterface { public: TorchInputInterface() {} - TorchInputInterface(const TorchInputInterface &i) - : _in(i._in), _attention_mask(i._attention_mask) {} + TorchInputInterface(const TorchInputInterface &i) {} ~TorchInputInterface() {} @@ -45,8 +79,8 @@ namespace dd return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong); } - at::Tensor _in; - at::Tensor _attention_mask; + TorchDataset _dataset; + TorchDataset _test_dataset; }; class ImgTorchInputFileConn : public ImgInputFileConn, public TorchInputInterface @@ -55,7 +89,7 @@ namespace dd ImgTorchInputFileConn() :ImgInputFileConn() {} ImgTorchInputFileConn(const ImgTorchInputFileConn &i) - :ImgInputFileConn(i),TorchInputInterface(i) {} + :ImgInputFileConn(i),TorchInputInterface(i), _std{i._std} {} ~ImgTorchInputFileConn() {} // for API info only @@ -108,8 +142,6 @@ namespace dd imgt = imgt.mul(1. / _std); tensors.push_back(imgt); } - - _in = torch::stack(tensors, 0); } public: diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index b76b62fdc..42c4fa680 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -29,6 +29,64 @@ using namespace torch; namespace dd { + // ======= TORCH MODULE + + + void add_parameters(std::shared_ptr module, std::vector ¶ms) { + for (const auto &slot : module->get_parameters()) { + params.push_back(slot.value().toTensor()); + } + for (auto child : module->get_modules()) { + add_parameters(child, params); + } + } + + Tensor to_tensor_safe(const IValue &value) { + if (!value.isTensor()) + throw MLLibInternalException("Model returned an invalid output. Please check your model."); + return value.toTensor(); + } + + c10::IValue TorchModule::forward(std::vector source) + { + if (_traced) + { + source = { _traced->forward(source) }; + } + c10::IValue out_val = source.at(0); + if (_classif) + { + out_val = _classif->forward(to_tensor_safe(out_val)); + } + return out_val; + } + + std::vector TorchModule::parameters() + { + std::vector params; + if (_traced) + add_parameters(_traced, params); + if (_classif) + { + auto classif_params = _classif->parameters(); + params.insert(params.end(), classif_params.begin(), classif_params.end()); + } + return params; + } + + void TorchModule::save(const std::string &filename) + { + // Not yet implemented + } + + void TorchModule::load(const std::string &filename) + { + // Not yet implemented + } + + + // ======= TORCHLIB + template TorchLib::TorchLib(const TorchModel &tmodel) : MLLib(tmodel) @@ -41,10 +99,9 @@ namespace dd : MLLib(std::move(tl)) { this->_libname = "torch"; - _traced = std::move(tl._traced); + _module = std::move(tl._module); _nclasses = tl._nclasses; _device = tl._device; - _attention = tl._attention; } template @@ -72,11 +129,11 @@ namespace dd _device = gpu ? torch::Device(DeviceType::CUDA, gpuid) : torch::Device(DeviceType::CPU); if (typeid(TInputConnectorStrategy) == typeid(TxtTorchInputFileConn)) { - _attention = true; + //_attention = true; } - _traced = torch::jit::load(this->_mlmodel._model_file, _device); - _traced->eval(); + _module._traced = torch::jit::load(this->_mlmodel._model_file, _device); + _module._traced->eval(); } template @@ -106,32 +163,19 @@ namespace dd } torch::Device cpu("cpu"); + inputc._dataset.reset(); std::vector in_vals; - in_vals.push_back(inputc._in.to(_device)); - - if (_attention) { - // token_type_ids - in_vals.push_back(torch::zeros_like(inputc._in, at::kLong).to(_device)); - in_vals.push_back(inputc._attention_mask.to(_device)); - } - + for (Tensor tensor : inputc._dataset.get_cached().data) + in_vals.push_back(tensor.to(_device)); Tensor output; try { - c10::IValue out_val = _traced->forward(in_vals); - if (out_val.isTuple()) { - out_val = out_val.toTuple()->elements()[0]; - } - if (!out_val.isTensor()) { - throw MLLibInternalException("Model returned an invalid output. Please check your model."); - } - output = out_val.toTensor().to(at::kFloat); + output = torch::softmax(to_tensor_safe(_module.forward(in_vals)), 1); } catch (std::exception &e) { throw MLLibInternalException(std::string("Libtorch error:") + e.what()); } - output = torch::softmax(output, 1).to(cpu); // Output std::vector results_ads; @@ -174,7 +218,6 @@ namespace dd return 0; } - template class TorchLib; template class TorchLib; } diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index ddb6809b5..b82b632b1 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -32,6 +32,21 @@ namespace dd { + class TorchModule { + public: + c10::IValue forward(std::vector source); + + std::vector parameters(); + + void save(const std::string &filename); + + void load(const std::string &filename); + public: + std::shared_ptr _traced; + torch::nn::Linear _classif = nullptr; + }; + + template class TorchLib : public MLLib { @@ -49,11 +64,14 @@ namespace dd int predict(const APIData &ad, APIData &out); + int test(const APIData &ad, APIData &out); + public: int _nclasses = 0; - bool _attention = false; torch::Device _device = torch::Device("cpu"); - std::shared_ptr _traced; + + // models + TorchModule _module; }; } From 02ecf8278da7a376ea424b99e94c94bb5a1a0645 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 6 Sep 2019 16:00:03 +0200 Subject: [PATCH 02/14] Add train and test The fix on txtinputconnector is temporary, vocab generation should be fixed a more robust way --- src/backends/torch/torchinputconns.cc | 33 +++- src/backends/torch/torchinputconns.h | 6 + src/backends/torch/torchlib.cc | 226 +++++++++++++++++++++++++- src/backends/torch/torchlib.h | 9 +- src/backends/torch/torchmodel.cc | 3 + src/txtinputfileconn.cc | 5 +- src/txtinputfileconn.h | 2 +- 7 files changed, 267 insertions(+), 17 deletions(-) diff --git a/src/backends/torch/torchinputconns.cc b/src/backends/torch/torchinputconns.cc index c1b100ea9..ed916a46c 100644 --- a/src/backends/torch/torchinputconns.cc +++ b/src/backends/torch/torchinputconns.cc @@ -76,6 +76,17 @@ TorchBatch TorchDataset::get_cached() { return batch.value(); } +TorchDataset TorchDataset::split(double start, double stop) +{ + auto datasize = _batches.size(); + auto start_it = _batches.begin() + static_cast(datasize * start); + auto stop_it = _batches.end() - static_cast(datasize * (1 - stop)); + + TorchDataset new_dataset; + new_dataset._batches.insert(new_dataset._batches.end(), start_it, stop_it); + return new_dataset; +} + void TxtTorchInputFileConn::transform(const APIData &ad) { try @@ -97,11 +108,19 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { if (!_ordered_words || _characters) throw InputConnectorBadParamException("Need ordered_words = true with backend torch"); + fill_dataset(_dataset, _txt); + if (!_test_txt.empty()) + fill_dataset(_test_dataset, _test_txt); +} + +void TxtTorchInputFileConn::fill_dataset(TorchDataset &dataset, + const std::vector*> &entries) +{ int cls_pos = _vocab.at("[CLS]")._pos; int sep_pos = _vocab.at("[SEP]")._pos; int unk_pos = _vocab.at("[UNK]")._pos; - for (auto *te : _txt) + for (auto *te : entries) { TxtOrderedWordsEntry *tow = static_cast(te); tow->reset(); @@ -140,8 +159,16 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { token_type_ids_tensor = torch::constant_pad_nd( token_type_ids_tensor, at::IntList{0, padding_size}, 0); - _dataset.add_batch({ids_tensor, token_type_ids_tensor, mask_tensor}, {}); + std::vector target_vec; + int target_val = static_cast(tow->_target); + if (target_val != -1) + { + Tensor target_tensor = torch::full(1, target_val, torch::kLong); + target_vec.push_back(target_tensor); + } + + dataset.add_batch({ids_tensor, token_type_ids_tensor, mask_tensor}, std::move(target_vec)); } } -} \ No newline at end of file +} diff --git a/src/backends/torch/torchinputconns.h b/src/backends/torch/torchinputconns.h index 303609c1a..3ba2c3814 100644 --- a/src/backends/torch/torchinputconns.h +++ b/src/backends/torch/torchinputconns.h @@ -58,10 +58,15 @@ namespace dd return cache_size(); } + bool empty() const { return cache_size() == 0; } + c10::optional get_batch(BatchRequestType request) override; // Returns a batch containing all the cached data TorchBatch get_cached(); + + // Split a percentage of this dataset + TorchDataset split(double start, double stop); }; @@ -177,6 +182,7 @@ namespace dd void transform(const APIData &ad); + void fill_dataset(TorchDataset &dataset, const std::vector*> &entries); public: int _width = 512; int _height = 0; diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index 42c4fa680..b052c5137 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -29,9 +29,6 @@ using namespace torch; namespace dd { - // ======= TORCH MODULE - - void add_parameters(std::shared_ptr module, std::vector ¶ms) { for (const auto &slot : module->get_parameters()) { params.push_back(slot.value().toTensor()); @@ -41,11 +38,37 @@ namespace dd } } + /// Convert IValue to Tensor and throw an exception if the IValue is not a Tensor. Tensor to_tensor_safe(const IValue &value) { if (!value.isTensor()) - throw MLLibInternalException("Model returned an invalid output. Please check your model."); + throw MLLibInternalException("Expected Tensor, found " + value.tagKind()); return value.toTensor(); } + + /// Convert id Tensor to one_hot Tensor + void fill_one_hot(Tensor &one_hot, Tensor ids, int nclasses) + { + one_hot.zero_(); + for (int i = 0; i < ids.size(0); ++i) + { + one_hot[i][ids[i].item()] = 1; + } + } + + Tensor to_one_hot(Tensor ids, int nclasses) + { + Tensor one_hot = torch::zeros(IntList{ids.size(0), nclasses}); + for (int i = 0; i < ids.size(0); ++i) + { + one_hot[i][ids[i].item()] = 1; + } + return one_hot; + } + + // ======= TORCH MODULE + + + TorchModule::TorchModule() {} c10::IValue TorchModule::forward(std::vector source) { @@ -74,12 +97,13 @@ namespace dd return params; } - void TorchModule::save(const std::string &filename) + void TorchModule::save_checkpoint(TorchModel &model, const std::string &name) { - // Not yet implemented + if (_traced) + _traced->save(model._repo + "/checkpoint-" + name + ".pt"); } - void TorchModule::load(const std::string &filename) + void TorchModule::load_checkpoint(const std::string &filename) { // Not yet implemented } @@ -132,6 +156,7 @@ namespace dd //_attention = true; } + // TODO load classification layer, or create it according to the number of classes _module._traced = torch::jit::load(this->_mlmodel._model_file, _device); _module._traced->eval(); } @@ -145,7 +170,129 @@ namespace dd template int TorchLib::train(const APIData &ad, APIData &out) { - + TInputConnectorStrategy inputc(this->_inputc); + inputc._train = true; + + try + { + inputc.transform(ad); + } + catch (...) + { + throw; + } + + APIData ad_mllib = ad.getobj("parameters").getobj("mllib"); + + // solver params + int64_t iterations = 100; + std::string solver_type = "SGD"; + double base_lr = 0.0001; + int64_t batch_size = 5; + int64_t test_batch_size = 1; + int64_t test_interval = 1; + int64_t save_period = 0; + int64_t batch_count = (inputc._dataset.cache_size() - 1) / batch_size + 1; + + // logging parameters + int64_t log_batch_period = 1; + + if (ad_mllib.has("iterations")) + iterations = ad_mllib.get("iterations").get(); + if (ad_mllib.has("solver_type")) + solver_type = ad_mllib.get("solver_type").get(); + if (ad_mllib.has("base_lr")) + base_lr = ad_mllib.get("base_lr").get(); + if (ad_mllib.has("test_interval")) + test_interval = ad_mllib.get("test_interval").get(); + if (ad_mllib.has("batch_size")) + batch_size = ad_mllib.get("batch_size").get(); + if (ad_mllib.has("test_batch_size")) + test_batch_size = ad_mllib.get("test_batch_size").get(); + if (ad_mllib.has("save_period")) + save_period = ad_mllib.get("save_period").get(); + + // create dataset for evaluation during training + TorchDataset eval_dataset; + if (!inputc._test_dataset.empty()) + { + eval_dataset = inputc._test_dataset; //.split(0, 0.1); + } + + // create solver + // no care about solver type yet + optim::Adam optimizer(_module.parameters(), optim::AdamOptions(base_lr)); + + // create dataloader + auto dataloader = torch::data::make_data_loader( + std::move(inputc._dataset), + data::DataLoaderOptions(batch_size) + ); + this->_logger->info("Training for {} iterations", iterations); + + for (int64_t epoch = 0; epoch < iterations; ++epoch) + { + this->add_meas("iteration", epoch); + this->_logger->info("Iteration {}", epoch); + int batch_id = 0; + + for (TorchBatch &example : *dataloader) + { + std::vector in_vals; + for (Tensor tensor : example.data) + in_vals.push_back(tensor.to(_device)); + Tensor y_pred = to_tensor_safe(_module.forward(in_vals)); + Tensor y = to_one_hot(example.target.at(0), _nclasses).to(_device); + + // TODO let loss be a parameter + auto loss = torch::mse_loss(y_pred, y); + double loss_val = loss.item(); + + optimizer.zero_grad(); + loss.backward(); + optimizer.step(); + + this->add_meas("train_loss", loss_val); + this->add_meas_per_iter("train_loss", loss_val); + if (log_batch_period != 0 && (batch_id + 1) % log_batch_period == 0) + { + this->_logger->info("Batch {}/{}: loss is {}", batch_id + 1, batch_count, loss_val); + } + ++batch_id; + } + + if (epoch > 0 && epoch % test_interval == 0 && !eval_dataset.empty()) + { + APIData meas_out; + test(ad, eval_dataset, test_batch_size, meas_out); + eval_dataset = inputc._test_dataset; // XXX eval_dataset is moved. find a way to unmove it + APIData meas_obj = meas_out.getobj("measure"); + std::vector meas_names = meas_obj.list_keys(); + + for (auto name : meas_names) + { + double mval = meas_obj.get(name).get(); + this->_logger->info("{}={}", name, mval); + this->add_meas(name, mval); + this->add_meas_per_iter(name, mval); + } + } + + int64_t check_id = epoch + 1; + if ((save_period != 0 && check_id % save_period == 0) || check_id == iterations) + { + this->_logger->info("Saving checkpoint after {} iterations", check_id); + _module.save_checkpoint(this->_mlmodel, std::to_string(check_id)); + } + } + + test(ad, inputc._test_dataset, test_batch_size, out); + + // TODO make model ready for predict after training + + inputc.response_params(out); + this->_logger->info("Training done."); + return 0; } template @@ -163,6 +310,15 @@ namespace dd } torch::Device cpu("cpu"); + if (output_params.has("measure")) + { + APIData meas_out; + test(ad, inputc._dataset, 1, meas_out); + meas_out.erase("iteration"); + out.add("measure", meas_out.getobj("measure")); + return 0; + } + inputc._dataset.reset(); std::vector in_vals; for (Tensor tensor : inputc._dataset.get_cached().data) @@ -218,6 +374,60 @@ namespace dd return 0; } + template + int TorchLib::test(const APIData &ad, + TorchDataset &dataset, + int batch_size, + APIData &out) + { + APIData ad_res; + APIData ad_out = ad.getobj("parameters").getobj("output"); + int test_size = dataset.cache_size(); + + // std::move may lead to unexpected behaviour from the input connector + auto dataloader = torch::data::make_data_loader( + std::move(dataset), + data::DataLoaderOptions(batch_size) + ); + + int entry_id = 0; + for (TorchBatch &batch : *dataloader) + { + std::vector in_vals; + for (Tensor tensor : batch.data) + in_vals.push_back(tensor.to(_device)); + Tensor output = torch::softmax(to_tensor_safe(_module.forward(in_vals)), 1); + if (batch.target.empty()) + throw MLLibBadParamException("Missing label on data while testing"); + Tensor labels = batch.target[0]; + + for (int j = 0; j < labels.size(0); ++j) { + APIData bad; + std::vector predictions; + for (int c = 0; c < _nclasses; c++) + { + predictions.push_back(output[j][c].item()); + } + bad.add("target", labels[j].item()); + bad.add("pred", predictions); + ad_res.add(std::to_string(entry_id), bad); + ++entry_id; + } + this->_logger->info("Testing: {}/{} entries processed", entry_id, test_size); + } + + ad_res.add("iteration",this->get_meas("iteration")); + ad_res.add("train_loss",this->get_meas("train_loss")); + std::vector clnames; + for (int i=0;i<_nclasses;i++) + clnames.push_back(this->_mlmodel.get_hcorresp(i)); + ad_res.add("clnames", clnames); + ad_res.add("nclasses", _nclasses); + ad_res.add("batch_size", entry_id); // here batch_size = tested entries count + SupervisedOutput::measure(ad_res, ad_out, out); + return 0; + } + template class TorchLib; template class TorchLib; } diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index b82b632b1..9bf0c9ab2 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -34,13 +34,15 @@ namespace dd { class TorchModule { public: + TorchModule(); + c10::IValue forward(std::vector source); std::vector parameters(); - void save(const std::string &filename); + void save_checkpoint(TorchModel &model, const std::string &name); - void load(const std::string &filename); + void load_checkpoint(const std::string &name); public: std::shared_ptr _traced; torch::nn::Linear _classif = nullptr; @@ -64,7 +66,8 @@ namespace dd int predict(const APIData &ad, APIData &out); - int test(const APIData &ad, APIData &out); + int test(const APIData &ad, TorchDataset &dataset, + int batch_size, APIData &out); public: int _nclasses = 0; diff --git a/src/backends/torch/torchmodel.cc b/src/backends/torch/torchmodel.cc index 1c5ceb234..37e5e5e81 100644 --- a/src/backends/torch/torchmodel.cc +++ b/src/backends/torch/torchmodel.cc @@ -43,6 +43,9 @@ namespace dd } } + // TODO detect and list checkpoints names + // detect last checkpoint + // detect classification layer for BERT return 0; } } \ No newline at end of file diff --git a/src/txtinputfileconn.cc b/src/txtinputfileconn.cc index 0288a1954..143be0015 100644 --- a/src/txtinputfileconn.cc +++ b/src/txtinputfileconn.cc @@ -208,8 +208,9 @@ namespace dd } // post-processing + // XXX: might want to use post-processing with ordered_words too size_t initial_vocab_size = _ctfc->_vocab.size(); - if (_ctfc->_train && !test_dir) + if (!_ctfc->_ordered_words && _ctfc->_train && !test_dir) { auto vhit = _ctfc->_vocab.begin(); while(vhit!=_ctfc->_vocab.end()) @@ -232,7 +233,7 @@ namespace dd } } - if (!_ctfc->_characters && !test_dir && (initial_vocab_size != _ctfc->_vocab.size() || _ctfc->_tfidf)) + if (!_ctfc->_ordered_words && !_ctfc->_characters && !test_dir && (initial_vocab_size != _ctfc->_vocab.size() || _ctfc->_tfidf)) { // clearing up the corpus + tfidf std::unordered_map::iterator whit; diff --git a/src/txtinputfileconn.h b/src/txtinputfileconn.h index 4c5bf7945..a7fc16907 100644 --- a/src/txtinputfileconn.h +++ b/src/txtinputfileconn.h @@ -347,7 +347,7 @@ namespace dd if (_alphabet.empty() && _characters) build_alphabet(); - if (!_characters && !_train && _vocab.empty()) + if (!_characters && (!_train || _ordered_words) && _vocab.empty()) deserialize_vocab(); for (std::string u: _uris) From 150781bb3b57a8765ab464fe4e5fee3da2404d07 Mon Sep 17 00:00:00 2001 From: Louis J Date: Thu, 8 Aug 2019 16:58:10 +0200 Subject: [PATCH 03/14] BERT finetuning with custom number of classes --- src/backends/torch/torchinputconns.cc | 12 ++- src/backends/torch/torchinputconns.h | 5 +- src/backends/torch/torchlib.cc | 140 +++++++++++++++++++------- src/backends/torch/torchlib.h | 9 +- src/backends/torch/torchmodel.cc | 35 +++++-- src/backends/torch/torchmodel.h | 3 +- 6 files changed, 156 insertions(+), 48 deletions(-) diff --git a/src/backends/torch/torchinputconns.cc b/src/backends/torch/torchinputconns.cc index ed916a46c..3d2890833 100644 --- a/src/backends/torch/torchinputconns.cc +++ b/src/backends/torch/torchinputconns.cc @@ -4,6 +4,7 @@ namespace dd { using namespace torch; +// ===== TorchDataset void TorchDataset::add_batch(std::vector data, std::vector target) { @@ -88,6 +89,8 @@ TorchDataset TorchDataset::split(double start, double stop) } +// ===== TxtTorchInputFileConn + void TxtTorchInputFileConn::transform(const APIData &ad) { try { @@ -98,6 +101,9 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { throw; } + if (!_ordered_words || _characters) + throw InputConnectorBadParamException("Need ordered_words = true with backend torch"); + if (ad.has("parameters") && ad.getobj("parameters").has("input")) { APIData ad_input = ad.getobj("parameters").getobj("input"); @@ -105,9 +111,6 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { _width = ad_input.get("width").get(); } - if (!_ordered_words || _characters) - throw InputConnectorBadParamException("Need ordered_words = true with backend torch"); - fill_dataset(_dataset, _txt); if (!_test_txt.empty()) fill_dataset(_test_dataset, _test_txt); @@ -130,6 +133,9 @@ void TxtTorchInputFileConn::fill_dataset(TorchDataset &dataset, while(tow->has_elt()) { + if (ids.size() >= _width - 1) + break; + std::string word; double val; tow->get_next_elt(word, val); diff --git a/src/backends/torch/torchinputconns.h b/src/backends/torch/torchinputconns.h index 3ba2c3814..59e9a6886 100644 --- a/src/backends/torch/torchinputconns.h +++ b/src/backends/torch/torchinputconns.h @@ -81,7 +81,7 @@ namespace dd torch::Tensor toLongTensor(std::vector &values) { int64_t val_size = values.size(); - return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong); + return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong).clone(); } TorchDataset _dataset; @@ -184,9 +184,10 @@ namespace dd void fill_dataset(TorchDataset &dataset, const std::vector*> &entries); public: + /** width of the input tensor */ int _width = 512; int _height = 0; }; } // namespace dd -#endif // TORCHINPUTCONNS_H \ No newline at end of file +#endif // TORCHINPUTCONNS_H diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index b052c5137..9e2963d97 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -68,15 +68,26 @@ namespace dd // ======= TORCH MODULE - TorchModule::TorchModule() {} + TorchModule::TorchModule() : _device{"cpu"} {} c10::IValue TorchModule::forward(std::vector source) { if (_traced) { - source = { _traced->forward(source) }; + auto output = _traced->forward(source); + if (output.isTensorList()) { + auto &elems = output.toTensorList()->elements(); + source = std::vector(elems.begin(), elems.end()); + } + else if (output.isTuple()) { + auto &elems = output.toTuple()->elements(); + source = std::vector(elems.begin(), elems.end()); + } + else { + source = { output }; + } } - c10::IValue out_val = source.at(0); + c10::IValue out_val = source.at(_classif_in); if (_classif) { out_val = _classif->forward(to_tensor_safe(out_val)); @@ -100,12 +111,31 @@ namespace dd void TorchModule::save_checkpoint(TorchModel &model, const std::string &name) { if (_traced) - _traced->save(model._repo + "/checkpoint-" + name + ".pt"); + _traced->save(model._repo + "/checkpoint-" + name + "-trace.pt"); + if (_classif) + torch::save(_classif, model._repo + "/checkpoint-" + name + ".pt"); } - void TorchModule::load_checkpoint(const std::string &filename) + void TorchModule::load(TorchModel &model) { - // Not yet implemented + if (!model._traced.empty()) + _traced = torch::jit::load(model._traced, _device); + if (!model._weights.empty()) + torch::load(_classif, model._weights); + } + + void TorchModule::eval() { + if (_traced) + _traced->eval(); + if (_classif) + _classif->eval(); + } + + void TorchModule::train() { + if (_traced) + _traced->train(); + if (_classif) + _classif->train(); } @@ -124,6 +154,7 @@ namespace dd { this->_libname = "torch"; _module = std::move(tl._module); + _template = tl._template; _nclasses = tl._nclasses; _device = tl._device; } @@ -138,27 +169,45 @@ namespace dd template void TorchLib::init_mllib(const APIData &lib_ad) { + bool classification = true; + bool finetuning = false; bool gpu = false; int gpuid = -1; - if (lib_ad.has("gpu")) { + if (lib_ad.has("gpu")) gpu = lib_ad.get("gpu").get() && torch::cuda::is_available(); - } if (lib_ad.has("gpuid")) gpuid = lib_ad.get("gpuid").get(); - if (lib_ad.has("nclasses")) { + if (lib_ad.has("nclasses")) + { + classification = true; _nclasses = lib_ad.get("nclasses").get(); } + if (lib_ad.has("template")) + _template = lib_ad.get("template").get(); + if (lib_ad.has("finetuning")) + finetuning = lib_ad.get("finetuning").get(); _device = gpu ? torch::Device(DeviceType::CUDA, gpuid) : torch::Device(DeviceType::CPU); + _module._device = _device; - if (typeid(TInputConnectorStrategy) == typeid(TxtTorchInputFileConn)) { - //_attention = true; + // Create the model + if (this->_mlmodel._traced.empty()) + throw MLLibInternalException("This template requires a traced net"); + + if (_template == "bert-classification") + { + if (!classification) + throw MLLibBadParamException("nclasses not specified"); + + // XXX: dont hard code BERT output size + _module._classif = nn::Linear(768, _nclasses); + _module._classif->to(_device); + + _module._classif_in = 1; } - // TODO load classification layer, or create it according to the number of classes - _module._traced = torch::jit::load(this->_mlmodel._model_file, _device); - _module._traced->eval(); + _module.load(this->_mlmodel); } template @@ -195,7 +244,7 @@ namespace dd int64_t batch_count = (inputc._dataset.cache_size() - 1) / batch_size + 1; // logging parameters - int64_t log_batch_period = 1; + int64_t log_batch_period = 20; if (ad_mllib.has("iterations")) iterations = ad_mllib.get("iterations").get(); @@ -220,8 +269,24 @@ namespace dd } // create solver - // no care about solver type yet - optim::Adam optimizer(_module.parameters(), optim::AdamOptions(base_lr)); + std::unique_ptr optimizer; + + if (solver_type == "ADAM") + optimizer = std::unique_ptr( + new optim::Adam(_module.parameters(), optim::AdamOptions(base_lr))); + else if (solver_type == "RMSPROP") + optimizer = std::unique_ptr( + new optim::RMSprop(_module.parameters(), optim::RMSpropOptions(base_lr))); + else if (solver_type == "ADAGRAD") + optimizer = std::unique_ptr( + new optim::Adagrad(_module.parameters(), optim::AdagradOptions(base_lr))); + else + { + if (solver_type != "SGD") + this->_logger->warn("Solver type {} not found, using SGD", solver_type); + optimizer = std::unique_ptr( + new optim::SGD(_module.parameters(), optim::SGDOptions(base_lr))); + } // create dataloader auto dataloader = torch::data::make_data_loader( @@ -232,6 +297,7 @@ namespace dd for (int64_t epoch = 0; epoch < iterations; ++epoch) { + _module.train(); this->add_meas("iteration", epoch); this->_logger->info("Iteration {}", epoch); int batch_id = 0; @@ -244,13 +310,12 @@ namespace dd Tensor y_pred = to_tensor_safe(_module.forward(in_vals)); Tensor y = to_one_hot(example.target.at(0), _nclasses).to(_device); - // TODO let loss be a parameter auto loss = torch::mse_loss(y_pred, y); double loss_val = loss.item(); - optimizer.zero_grad(); + optimizer->zero_grad(); loss.backward(); - optimizer.step(); + optimizer->step(); this->add_meas("train_loss", loss_val); this->add_meas_per_iter("train_loss", loss_val); @@ -261,34 +326,38 @@ namespace dd ++batch_id; } - if (epoch > 0 && epoch % test_interval == 0 && !eval_dataset.empty()) + int64_t elapsed_it = epoch + 1; + if (elapsed_it % test_interval == 0 && !eval_dataset.empty()) { APIData meas_out; test(ad, eval_dataset, test_batch_size, meas_out); - eval_dataset = inputc._test_dataset; // XXX eval_dataset is moved. find a way to unmove it - APIData meas_obj = meas_out.getobj("measure"); + APIData meas_obj = meas_out.getobj("measure"); std::vector meas_names = meas_obj.list_keys(); for (auto name : meas_names) { - double mval = meas_obj.get(name).get(); - this->_logger->info("{}={}", name, mval); - this->add_meas(name, mval); - this->add_meas_per_iter(name, mval); + if (name != "cmdiag" && name != "cmfull" && name != "labels") + { + double mval = meas_obj.get(name).get(); + this->_logger->info("{}={}", name, mval); + this->add_meas(name, mval); + this->add_meas_per_iter(name, mval); + } } } - int64_t check_id = epoch + 1; - if ((save_period != 0 && check_id % save_period == 0) || check_id == iterations) + if ((save_period != 0 && elapsed_it % save_period == 0) || elapsed_it == iterations) { - this->_logger->info("Saving checkpoint after {} iterations", check_id); - _module.save_checkpoint(this->_mlmodel, std::to_string(check_id)); + this->_logger->info("Saving checkpoint after {} iterations", elapsed_it); + _module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it)); } } test(ad, inputc._test_dataset, test_batch_size, out); - // TODO make model ready for predict after training + // Update model after training + this->_mlmodel.read_from_repository(this->_logger); + this->_mlmodel.read_corresp_file(); inputc.response_params(out); this->_logger->info("Training done."); @@ -310,6 +379,8 @@ namespace dd } torch::Device cpu("cpu"); + _module.eval(); + if (output_params.has("measure")) { APIData meas_out; @@ -386,10 +457,11 @@ namespace dd // std::move may lead to unexpected behaviour from the input connector auto dataloader = torch::data::make_data_loader( - std::move(dataset), + dataset, data::DataLoaderOptions(batch_size) ); + _module.eval(); int entry_id = 0; for (TorchBatch &batch : *dataloader) { @@ -413,7 +485,7 @@ namespace dd ad_res.add(std::to_string(entry_id), bad); ++entry_id; } - this->_logger->info("Testing: {}/{} entries processed", entry_id, test_size); + // this->_logger->info("Testing: {}/{} entries processed", entry_id, test_size); } ad_res.add("iteration",this->get_meas("iteration")); diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index 9bf0c9ab2..5b84e6d85 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -42,10 +42,16 @@ namespace dd void save_checkpoint(TorchModel &model, const std::string &name); - void load_checkpoint(const std::string &name); + void load(TorchModel &model); + + void eval(); + void train(); public: std::shared_ptr _traced; torch::nn::Linear _classif = nullptr; + + torch::Device _device; + int _classif_in = 0;/** &logger) { + const std::string traced = "trace.pt"; + const std::string weights = ".pt"; + const std::string corresp = "corresp"; + std::unordered_set files; int err = fileops::list_directory(_repo, true, false, false, files); @@ -34,18 +38,35 @@ namespace dd return 1; } + std::string tracedf,weightsf,correspf; + int traced_t = -1, weights_t = -1, corresp_t = -1; + for (const auto &file : files) { - if (file.find(".pt") != std::string::npos) { - _model_file = file; + long int lm = fileops::file_last_modif(file); + if (file.find(traced) != std::string::npos) { + if (traced_t < lm) { + tracedf = file; + traced_t = lm; + } } - else if (file.find("corresp") != std::string::npos) { - _corresp = file; + else if (file.find(weights) != std::string::npos) { + if (weights_t < lm) { + weightsf = file; + weights_t = lm; + } + } + else if (file.find(corresp) != std::string::npos) { + if (corresp_t < lm) { + correspf = file; + corresp_t = lm; + } } } - // TODO detect and list checkpoints names - // detect last checkpoint - // detect classification layer for BERT + _traced = tracedf; + _weights = weightsf; + _corresp = correspf; + return 0; } } \ No newline at end of file diff --git a/src/backends/torch/torchmodel.h b/src/backends/torch/torchmodel.h index 2b35294c1..dc2df0032 100644 --- a/src/backends/torch/torchmodel.h +++ b/src/backends/torch/torchmodel.h @@ -48,7 +48,8 @@ namespace dd int read_from_repository(const std::shared_ptr &logger); public: - std::string _model_file; + std::string _traced;/**< path of the traced part of the net. */ + std::string _weights;/**< path of the weights of the net. */ }; } From dbc611cac990b3c2b4608c8f7d187da83db5eb5d Mon Sep 17 00:00:00 2001 From: Louis J Date: Tue, 13 Aug 2019 14:38:30 +0200 Subject: [PATCH 04/14] Add self supervised Masked LM learning --- src/backends/torch/torchinputconns.cc | 61 ++++++- src/backends/torch/torchinputconns.h | 38 +++- src/backends/torch/torchlib.cc | 253 +++++++++++++++++++------- src/backends/torch/torchlib.h | 8 +- src/backends/torch/torchmodel.cc | 18 +- 5 files changed, 291 insertions(+), 87 deletions(-) diff --git a/src/backends/torch/torchinputconns.cc b/src/backends/torch/torchinputconns.cc index 3d2890833..92908e0cf 100644 --- a/src/backends/torch/torchinputconns.cc +++ b/src/backends/torch/torchinputconns.cc @@ -111,25 +111,70 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { _width = ad_input.get("width").get(); } + _cls_pos = _vocab.at("[CLS]")._pos; + _sep_pos = _vocab.at("[SEP]")._pos; + _unk_pos = _vocab.at("[UNK]")._pos; + _mask_id = _vocab.at("[MASK]")._pos; + fill_dataset(_dataset, _txt); if (!_test_txt.empty()) fill_dataset(_test_dataset, _test_txt); } +TorchBatch TxtTorchInputFileConn::generate_masked_lm_batch(const TorchBatch &example) +{ + std::uniform_real_distribution uniform(0, 1); + std::uniform_int_distribution vocab_distrib(0, vocab_size()); + Tensor input_ids = example.data.at(0).clone(); + Tensor lm_labels = torch::ones_like(input_ids, TensorOptions(kLong)) * -1; + + // mask random tokens + auto input_acc = input_ids.accessor(); + auto att_mask_acc = example.data.at(2).accessor(); + auto labels_acc = lm_labels.accessor(); + for (int i = 0; i < input_ids.size(0); ++i) + { + int j = 1; // skip [CLS] token + while (j < input_ids.size(1) && att_mask_acc[i][j] != 0) + { + double rand_num = uniform(_rng); + if (rand_num < _lm_params._change_prob && input_acc[i][j] != _sep_pos) + { + labels_acc[i][j] = input_acc[i][j]; + + rand_num = uniform(_rng); + if (rand_num < _lm_params._mask_prob) + { + input_acc[i][j] = mask_id(); + } + else if (rand_num < _lm_params._mask_prob + _lm_params._rand_prob) + { + input_acc[i][j] = vocab_distrib(_rng); + } + } + ++j; + } + } + + TorchBatch output; + output.target.push_back(lm_labels); + output.data.push_back(input_ids); + for (int i = 1; i < example.data.size(); ++i) + { + output.data.push_back(example.data[i]); + } + return output; +} + void TxtTorchInputFileConn::fill_dataset(TorchDataset &dataset, const std::vector*> &entries) { - int cls_pos = _vocab.at("[CLS]")._pos; - int sep_pos = _vocab.at("[SEP]")._pos; - int unk_pos = _vocab.at("[UNK]")._pos; - for (auto *te : entries) { TxtOrderedWordsEntry *tow = static_cast(te); tow->reset(); - std::vector ids; - ids.push_back(cls_pos); + ids.push_back(_cls_pos); while(tow->has_elt()) { @@ -147,11 +192,11 @@ void TxtTorchInputFileConn::fill_dataset(TorchDataset &dataset, } else { - ids.push_back(unk_pos); + ids.push_back(_unk_pos); } } - ids.push_back(sep_pos); + ids.push_back(_sep_pos); at::Tensor ids_tensor = toLongTensor(ids); at::Tensor mask_tensor = torch::ones_like(ids_tensor); diff --git a/src/backends/torch/torchinputconns.h b/src/backends/torch/torchinputconns.h index 59e9a6886..514a38589 100644 --- a/src/backends/torch/torchinputconns.h +++ b/src/backends/torch/torchinputconns.h @@ -51,7 +51,7 @@ namespace dd void reset(); - // Size of data loaded in memory + /// Size of data loaded in memory size_t cache_size() const { return _batches.size(); } c10::optional size() const override { @@ -62,20 +62,31 @@ namespace dd c10::optional get_batch(BatchRequestType request) override; - // Returns a batch containing all the cached data + /// Returns a batch containing all the cached data TorchBatch get_cached(); - // Split a percentage of this dataset + /// Split a percentage of this dataset TorchDataset split(double start, double stop); }; + struct MaskedLMParams + { + double _change_prob = 0.15; /**< When masked LM learning, probability of changing a token (mask/randomize/keep). */ + double _mask_prob = 0.8; /**< When masked LM learning, probability of masking a token. */ + double _rand_prob = 0.1; /**< When masked LM learning, probability of randomizing a token. */ + }; + class TorchInputInterface { public: TorchInputInterface() {} - TorchInputInterface(const TorchInputInterface &i) {} + TorchInputInterface(const TorchInputInterface &i) { + _lm_params = i._lm_params; + _dataset = i._dataset; + _test_dataset = i._test_dataset; + } ~TorchInputInterface() {} @@ -84,8 +95,15 @@ namespace dd return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong).clone(); } + TorchBatch generate_masked_lm_batch(const TorchBatch &example) { return {}; } + + int64_t mask_id() const { return 0; } + int64_t vocab_size() const { return 0; } + TorchDataset _dataset; TorchDataset _test_dataset; + + MaskedLMParams _lm_params; }; class ImgTorchInputFileConn : public ImgInputFileConn, public TorchInputInterface @@ -180,13 +198,25 @@ namespace dd return _height; } + int64_t mask_id() const { return _mask_id; } + + int64_t vocab_size() const { return _vocab.size(); } + void transform(const APIData &ad); + TorchBatch generate_masked_lm_batch(const TorchBatch &example); + void fill_dataset(TorchDataset &dataset, const std::vector*> &entries); public: /** width of the input tensor */ int _width = 512; int _height = 0; + std::mt19937 _rng; + + int64_t _mask_id = -1; /**< ID of mask token in the vocabulary. */ + int64_t _cls_pos = -1; + int64_t _sep_pos = -1; + int64_t _unk_pos = -1; }; } // namespace dd diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index 9e2963d97..c7e3b7f14 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -111,9 +111,9 @@ namespace dd void TorchModule::save_checkpoint(TorchModel &model, const std::string &name) { if (_traced) - _traced->save(model._repo + "/checkpoint-" + name + "-trace.pt"); + _traced->save(model._repo + "/checkpoint-" + name + ".pt"); if (_classif) - torch::save(_classif, model._repo + "/checkpoint-" + name + ".pt"); + torch::save(_classif, model._repo + "/checkpoint-" + name + ".ptw"); } void TorchModule::load(TorchModel &model) @@ -157,6 +157,7 @@ namespace dd _template = tl._template; _nclasses = tl._nclasses; _device = tl._device; + _masked_lm = tl._masked_lm; } template @@ -195,25 +196,37 @@ namespace dd if (this->_mlmodel._traced.empty()) throw MLLibInternalException("This template requires a traced net"); - if (_template == "bert-classification") + if (_masked_lm) { - if (!classification) - throw MLLibBadParamException("nclasses not specified"); - - // XXX: dont hard code BERT output size - _module._classif = nn::Linear(768, _nclasses); - _module._classif->to(_device); - - _module._classif_in = 1; + _module._classif_in = 0; } + else + { + if (_template == "bert-classification") + { + if (!classification) + throw MLLibBadParamException("nclasses not specified"); + + // XXX: dont hard code BERT output size + _module._classif = nn::Linear(768, _nclasses); + _module._classif->to(_device); + _module._classif_in = 1; + } + } + + this->_logger->info("Loading ml model from file {}.", this->_mlmodel._traced); + this->_logger->info("Loading weights from file {}.", this->_mlmodel._weights); _module.load(this->_mlmodel); + + this->_mltype = "classification"; } template void TorchLib::clear_mllib(const APIData &ad) { - + std::vector extensions{".json"}; + fileops::remove_directory_files(this->_mlmodel._repo, extensions); } template @@ -238,10 +251,10 @@ namespace dd std::string solver_type = "SGD"; double base_lr = 0.0001; int64_t batch_size = 5; + int64_t iter_size = 1; int64_t test_batch_size = 1; int64_t test_interval = 1; int64_t save_period = 0; - int64_t batch_count = (inputc._dataset.cache_size() - 1) / batch_size + 1; // logging parameters int64_t log_batch_period = 20; @@ -256,11 +269,16 @@ namespace dd test_interval = ad_mllib.get("test_interval").get(); if (ad_mllib.has("batch_size")) batch_size = ad_mllib.get("batch_size").get(); + if (ad_mllib.has("iter_size")) + iter_size = ad_mllib.get("iter_size").get(); if (ad_mllib.has("test_batch_size")) test_batch_size = ad_mllib.get("test_batch_size").get(); if (ad_mllib.has("save_period")) save_period = ad_mllib.get("save_period").get(); + if (iter_size <= 0) + iter_size = 1; + // create dataset for evaluation during training TorchDataset eval_dataset; if (!inputc._test_dataset.empty()) @@ -287,73 +305,128 @@ namespace dd optimizer = std::unique_ptr( new optim::SGD(_module.parameters(), optim::SGDOptions(base_lr))); } + optimizer->zero_grad(); + _module.train(); // create dataloader auto dataloader = torch::data::make_data_loader( std::move(inputc._dataset), data::DataLoaderOptions(batch_size) ); + this->_logger->info("Training for {} iterations", iterations); + int it = 0; + int batch_id = 0; + using namespace std::chrono; - for (int64_t epoch = 0; epoch < iterations; ++epoch) + while (it < iterations) { - _module.train(); - this->add_meas("iteration", epoch); - this->_logger->info("Iteration {}", epoch); - int batch_id = 0; + double train_loss = 0; + double avg_it_time = 0; for (TorchBatch &example : *dataloader) { + auto tstart = system_clock::now(); + TorchBatch batch; + if (_masked_lm) + { + batch = inputc.generate_masked_lm_batch(example); + } + else + { + batch = example; + batch.target.at(0) = to_one_hot(batch.target.at(0), _nclasses); + } std::vector in_vals; - for (Tensor tensor : example.data) + for (Tensor tensor : batch.data) in_vals.push_back(tensor.to(_device)); - Tensor y_pred = to_tensor_safe(_module.forward(in_vals)); - Tensor y = to_one_hot(example.target.at(0), _nclasses).to(_device); - - auto loss = torch::mse_loss(y_pred, y); - double loss_val = loss.item(); + Tensor y = batch.target.at(0).to(_device); - optimizer->zero_grad(); - loss.backward(); - optimizer->step(); + Tensor y_pred; + try + { + y_pred = to_tensor_safe(_module.forward(in_vals)); + } + catch (std::exception &e) + { + throw MLLibInternalException(std::string("Libtorch error:") + e.what()); + } - this->add_meas("train_loss", loss_val); - this->add_meas_per_iter("train_loss", loss_val); - if (log_batch_period != 0 && (batch_id + 1) % log_batch_period == 0) + Tensor loss; + if (_masked_lm) { - this->_logger->info("Batch {}/{}: loss is {}", batch_id + 1, batch_count, loss_val); + loss = torch::nll_loss( + torch::log_softmax(y_pred.view(IntList{-1, y_pred.size(2)}), 1), + y.view(IntList{-1}), + {}, Reduction::Mean, -1 + ); } - ++batch_id; - } - - int64_t elapsed_it = epoch + 1; - if (elapsed_it % test_interval == 0 && !eval_dataset.empty()) - { - APIData meas_out; - test(ad, eval_dataset, test_batch_size, meas_out); - APIData meas_obj = meas_out.getobj("measure"); - std::vector meas_names = meas_obj.list_keys(); + else + { + // TODO: Better choice for the loss + loss = torch::mse_loss(y_pred, y); + } + if (iter_size > 1) + loss /= iter_size; - for (auto name : meas_names) + double loss_val = loss.item(); + train_loss += loss_val; + loss.backward(); + auto tstop = system_clock::now(); + avg_it_time += duration_cast(tstop - tstart).count(); + + if ((batch_id + 1) % iter_size == 0) { - if (name != "cmdiag" && name != "cmfull" && name != "labels") + optimizer->step(); + optimizer->zero_grad(); + avg_it_time /= iter_size; + this->add_meas("iteration", it); + this->add_meas("iter_time", avg_it_time); + this->add_meas("remain_time", avg_it_time * iter_size * (iterations - it) / 1000.0); + this->add_meas("train_loss", train_loss); + this->add_meas_per_iter("train_loss", train_loss); + int64_t elapsed_it = it + 1; + if (log_batch_period != 0 && elapsed_it % log_batch_period == 0) + { + this->_logger->info("Iteration {}/{}: loss is {}", elapsed_it, iterations, train_loss); + } + train_loss = 0; + + if (elapsed_it % test_interval == 0 && elapsed_it != iterations && !eval_dataset.empty()) { - double mval = meas_obj.get(name).get(); - this->_logger->info("{}={}", name, mval); - this->add_meas(name, mval); - this->add_meas_per_iter(name, mval); + APIData meas_out; + test(ad, inputc, eval_dataset, test_batch_size, meas_out); + APIData meas_obj = meas_out.getobj("measure"); + std::vector meas_names = meas_obj.list_keys(); + + for (auto name : meas_names) + { + if (name != "cmdiag" && name != "cmfull" && name != "labels") + { + double mval = meas_obj.get(name).get(); + this->_logger->info("{}={}", name, mval); + this->add_meas(name, mval); + this->add_meas_per_iter(name, mval); + } + } } + + if ((save_period != 0 && elapsed_it % save_period == 0) || elapsed_it == iterations) + { + this->_logger->info("Saving checkpoint after {} iterations", elapsed_it); + _module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it)); + } + ++it; + + if (it >= iterations) + break; } - } - if ((save_period != 0 && elapsed_it % save_period == 0) || elapsed_it == iterations) - { - this->_logger->info("Saving checkpoint after {} iterations", elapsed_it); - _module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it)); + ++batch_id; } } - test(ad, inputc._test_dataset, test_batch_size, out); + test(ad, inputc, inputc._test_dataset, test_batch_size, out); // Update model after training this->_mlmodel.read_from_repository(this->_logger); @@ -384,7 +457,7 @@ namespace dd if (output_params.has("measure")) { APIData meas_out; - test(ad, inputc._dataset, 1, meas_out); + test(ad, inputc, inputc._dataset, 1, meas_out); meas_out.erase("iteration"); out.add("measure", meas_out.getobj("measure")); return 0; @@ -447,6 +520,7 @@ namespace dd template int TorchLib::test(const APIData &ad, + TInputConnectorStrategy &inputc, TorchDataset &dataset, int batch_size, APIData &out) @@ -454,34 +528,85 @@ namespace dd APIData ad_res; APIData ad_out = ad.getobj("parameters").getobj("output"); int test_size = dataset.cache_size(); + int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses; + + // confusion matrix is irrelevant to masked_lm training + if (_masked_lm && ad_out.has("measure")) + { + auto meas = ad_out.get("measure").get>(); + std::vector::iterator it; + if ((it = std::find(meas.begin(), meas.end(), "cmfull")) != meas.end()) + meas.erase(it); + if ((it = std::find(meas.begin(), meas.end(), "cmdiag")) != meas.end()) + meas.erase(it); + ad_out.add("measure", meas); + } - // std::move may lead to unexpected behaviour from the input connector auto dataloader = torch::data::make_data_loader( dataset, data::DataLoaderOptions(batch_size) ); + torch::Device cpu("cpu"); _module.eval(); int entry_id = 0; - for (TorchBatch &batch : *dataloader) + for (TorchBatch &example : *dataloader) { + TorchBatch batch; + if (_masked_lm) + { + batch = inputc.generate_masked_lm_batch(example); + } + else + { + batch = example; + } std::vector in_vals; for (Tensor tensor : batch.data) in_vals.push_back(tensor.to(_device)); - Tensor output = torch::softmax(to_tensor_safe(_module.forward(in_vals)), 1); + + Tensor output; + try + { + output = to_tensor_safe(_module.forward(in_vals)); + } + catch (std::exception &e) + { + throw MLLibInternalException(std::string("Libtorch error:") + e.what()); + } + if (batch.target.empty()) throw MLLibBadParamException("Missing label on data while testing"); Tensor labels = batch.target[0]; - - for (int j = 0; j < labels.size(0); ++j) { + + if (_masked_lm) + { + output = output.view(IntList{-1, output.size(2)}); + labels = labels.view(IntList{-1}); + } + output = torch::softmax(output, 1).to(cpu); + auto output_acc = output.accessor(); + auto labels_acc = labels.accessor(); + + for (int j = 0; j < labels.size(0); ++j) + { + if (_masked_lm && labels_acc[j] == -1) + continue; + APIData bad; std::vector predictions; - for (int c = 0; c < _nclasses; c++) + for (int c = 0; c < nclasses; c++) { - predictions.push_back(output[j][c].item()); + predictions.push_back(output_acc[j][c]); } - bad.add("target", labels[j].item()); + bad.add("target", static_cast(labels_acc[j])); bad.add("pred", predictions); + // auto &tinputc = reinterpret_cast(inputc); + /* + std::cout << "target: " << labels_acc[j] << std::endl; + std::cout << "masking: " << batch.data[0][0][j] << " " << batch.data[2][0][j] << std::endl; + std::cout << "pred: " << std::distance(predictions.begin(), std::max_element(predictions.begin(), predictions.end())) << std::endl << std::endl; + */ ad_res.add(std::to_string(entry_id), bad); ++entry_id; } @@ -491,10 +616,10 @@ namespace dd ad_res.add("iteration",this->get_meas("iteration")); ad_res.add("train_loss",this->get_meas("train_loss")); std::vector clnames; - for (int i=0;i<_nclasses;i++) + for (int i=0;i< nclasses;i++) clnames.push_back(this->_mlmodel.get_hcorresp(i)); ad_res.add("clnames", clnames); - ad_res.add("nclasses", _nclasses); + ad_res.add("nclasses", nclasses); ad_res.add("batch_size", entry_id); // here batch_size = tested entries count SupervisedOutput::measure(ad_res, ad_out, out); return 0; diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index 5b84e6d85..3c9ade21a 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -22,6 +22,8 @@ #ifndef TORCHLIB_H #define TORCHLIB_H +#include + #include #include "apidata.h" @@ -51,7 +53,7 @@ namespace dd torch::nn::Linear _classif = nullptr; torch::Device _device; - int _classif_in = 0;/** &logger) { - const std::string traced = "trace.pt"; - const std::string weights = ".pt"; + const std::string weights = ".ptw"; + const std::string traced = ".pt"; const std::string corresp = "corresp"; std::unordered_set files; @@ -43,18 +43,18 @@ namespace dd for (const auto &file : files) { long int lm = fileops::file_last_modif(file); - if (file.find(traced) != std::string::npos) { - if (traced_t < lm) { - tracedf = file; - traced_t = lm; - } - } - else if (file.find(weights) != std::string::npos) { + if (file.find(weights) != std::string::npos) { if (weights_t < lm) { weightsf = file; weights_t = lm; } } + else if (file.find(traced) != std::string::npos) { + if (traced_t < lm) { + tracedf = file; + traced_t = lm; + } + } else if (file.find(corresp) != std::string::npos) { if (corresp_t < lm) { correspf = file; From 8cc4914add09075ce5ab8c15438c250e1cc06cf0 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Sep 2019 15:18:54 +0200 Subject: [PATCH 05/14] Save solver checkpoint along with model --- src/backends/torch/torchlib.cc | 9 +++++++++ src/backends/torch/torchmodel.cc | 16 ++++++++++++---- src/backends/torch/torchmodel.h | 3 ++- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index c7e3b7f14..c13c41e0b 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -227,6 +227,7 @@ namespace dd { std::vector extensions{".json"}; fileops::remove_directory_files(this->_mlmodel._repo, extensions); + this->_logger->info("Torchlib service cleared"); } template @@ -305,6 +306,12 @@ namespace dd optimizer = std::unique_ptr( new optim::SGD(_module.parameters(), optim::SGDOptions(base_lr))); } + // reload solver + if (!this->_mlmodel._sstate.empty()) + { + this->_logger->info("Reload solver from {}", this->_mlmodel._sstate); + torch::load(*optimizer, this->_mlmodel._sstate); + } optimizer->zero_grad(); _module.train(); @@ -415,6 +422,8 @@ namespace dd { this->_logger->info("Saving checkpoint after {} iterations", elapsed_it); _module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it)); + // Save optimizer + torch::save(*optimizer, this->_mlmodel._repo + "/solver-" + std::to_string(elapsed_it) + ".pt"); } ++it; diff --git a/src/backends/torch/torchmodel.cc b/src/backends/torch/torchmodel.cc index 01d1997d9..95d77866b 100644 --- a/src/backends/torch/torchmodel.cc +++ b/src/backends/torch/torchmodel.cc @@ -29,6 +29,7 @@ namespace dd const std::string weights = ".ptw"; const std::string traced = ".pt"; const std::string corresp = "corresp"; + const std::string sstate = "solver"; std::unordered_set files; int err = fileops::list_directory(_repo, true, false, false, files); @@ -38,12 +39,18 @@ namespace dd return 1; } - std::string tracedf,weightsf,correspf; - int traced_t = -1, weights_t = -1, corresp_t = -1; + std::string tracedf,weightsf,correspf,sstatef; + int traced_t = -1, weights_t = -1, corresp_t = -1, sstate_t = -1; for (const auto &file : files) { long int lm = fileops::file_last_modif(file); - if (file.find(weights) != std::string::npos) { + if (file.find(sstate) != std::string::npos) { + if (sstate_t < lm) { + sstatef = file; + sstate_t = lm; + } + } + else if (file.find(weights) != std::string::npos) { if (weights_t < lm) { weightsf = file; weights_t = lm; @@ -66,7 +73,8 @@ namespace dd _traced = tracedf; _weights = weightsf; _corresp = correspf; + _sstate = sstatef; return 0; } -} \ No newline at end of file +} diff --git a/src/backends/torch/torchmodel.h b/src/backends/torch/torchmodel.h index dc2df0032..a3e695d9a 100644 --- a/src/backends/torch/torchmodel.h +++ b/src/backends/torch/torchmodel.h @@ -50,7 +50,8 @@ namespace dd public: std::string _traced;/**< path of the traced part of the net. */ std::string _weights;/**< path of the weights of the net. */ + std::string _sstate;/**< current solver state to resume training */ }; } -#endif // TORCHMODEL_H \ No newline at end of file +#endif // TORCHMODEL_H From a5a04910ffa537f642ba5a8bc5c57aa69ce0d3a3 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 6 Sep 2019 15:08:22 +0200 Subject: [PATCH 06/14] Ensure label is of correct dimension --- src/backends/torch/torchlib.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index c13c41e0b..4ddef846b 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -586,12 +586,10 @@ namespace dd if (batch.target.empty()) throw MLLibBadParamException("Missing label on data while testing"); - Tensor labels = batch.target[0]; - + Tensor labels = batch.target[0].view(IntList{-1}); if (_masked_lm) { output = output.view(IntList{-1, output.size(2)}); - labels = labels.view(IntList{-1}); } output = torch::softmax(output, 1).to(cpu); auto output_acc = output.accessor(); @@ -610,12 +608,6 @@ namespace dd } bad.add("target", static_cast(labels_acc[j])); bad.add("pred", predictions); - // auto &tinputc = reinterpret_cast(inputc); - /* - std::cout << "target: " << labels_acc[j] << std::endl; - std::cout << "masking: " << batch.data[0][0][j] << " " << batch.data[2][0][j] << std::endl; - std::cout << "pred: " << std::distance(predictions.begin(), std::max_element(predictions.begin(), predictions.end())) << std::endl << std::endl; - */ ad_res.add(std::to_string(entry_id), bad); ++entry_id; } From 00dc152a246129160b52cfd7c5b657fba06198a8 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 10 Sep 2019 16:03:39 +0200 Subject: [PATCH 07/14] Fix masked_lm, add more explicit error message --- src/backends/torch/torchlib.cc | 9 +++++++-- src/txtinputfileconn.cc | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index 4ddef846b..c7e3b4324 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -31,7 +31,9 @@ namespace dd { void add_parameters(std::shared_ptr module, std::vector ¶ms) { for (const auto &slot : module->get_parameters()) { - params.push_back(slot.value().toTensor()); + Tensor tensor = slot.value().toTensor(); + if (tensor.requires_grad()) + params.push_back(slot.value().toTensor()); } for (auto child : module->get_modules()) { add_parameters(child, params); @@ -188,6 +190,8 @@ namespace dd _template = lib_ad.get("template").get(); if (lib_ad.has("finetuning")) finetuning = lib_ad.get("finetuning").get(); + if (lib_ad.has("masked_lm")) + _masked_lm = lib_ad.get("masked_lm").get(); _device = gpu ? torch::Device(DeviceType::CUDA, gpuid) : torch::Device(DeviceType::CPU); _module._device = _device; @@ -479,7 +483,8 @@ namespace dd Tensor output; try { - output = torch::softmax(to_tensor_safe(_module.forward(in_vals)), 1); + output = to_tensor_safe(_module.forward(in_vals)); + output = torch::softmax(output, 1).to(cpu); } catch (std::exception &e) { diff --git a/src/txtinputfileconn.cc b/src/txtinputfileconn.cc index 143be0015..7814252bd 100644 --- a/src/txtinputfileconn.cc +++ b/src/txtinputfileconn.cc @@ -488,7 +488,9 @@ namespace dd while(getline(in,line)) { std::vector tokens = dd_utils::split(line,_vocab_sep); - std::string key = tokens.at(0); + if (tokens.size() < 2) + throw InputConnectorBadParamException("Error in vocabulary file " + vocabfname); + std::string key = tokens.at(0); int pos = std::atoi(tokens.at(1).c_str()); _vocab.emplace(std::make_pair(key,Word(pos))); } From aca96147ef406b977231794ef81866c30ed6f214 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 12 Sep 2019 14:12:53 +0200 Subject: [PATCH 08/14] Add script to trace huggingface models --- tools/torch/README.md | 8 ++ tools/torch/trace_pytorch_transformers.py | 157 ++++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100755 tools/torch/trace_pytorch_transformers.py diff --git a/tools/torch/README.md b/tools/torch/README.md index 57d8d189d..ce2b70dce 100644 --- a/tools/torch/README.md +++ b/tools/torch/README.md @@ -8,3 +8,11 @@ Utility script to trace the models included in torchvision. Requires torchvision ``` pip3 install --user torchvision ``` + +* `trace_pytorch_transformers.py` + +Utility script to trace NLP models from Huggingface's pytorch-transformers. Requires pytorch-transformers 1.1: +``` +pip3 install --user pytorch-transformers==1.1 +``` +At the moment a CUDA model can not be convertible to CPU and vice versa. This may change with a future version of pytorch-transformers. diff --git a/tools/torch/trace_pytorch_transformers.py b/tools/torch/trace_pytorch_transformers.py new file mode 100755 index 000000000..0fd208a27 --- /dev/null +++ b/tools/torch/trace_pytorch_transformers.py @@ -0,0 +1,157 @@ +#!/usr/bin/python3 +import sys +import os +import argparse +import logging + +import torch +import torch.nn as nn +import pytorch_transformers as M + +parser = argparse.ArgumentParser(description="Trace NLP models from pytorch-transformers") +parser.add_argument('models', type=str, nargs='*', help="Models to trace.") +parser.add_argument('--print-models', action='store_true', help="Print all the available models names and exit") +parser.add_argument('-a', "--all", action='store_true', help="Export all available models") +parser.add_argument('-v', "--verbose", action='store_true', help="Set logging level to INFO") +parser.add_argument('-o', "--output-dir", default=".", type=str, help="Output directory for traced models") +parser.add_argument('-p', "--not-pretrained", dest="pretrained", action='store_false', + help="Whether the exported models should not be pretrained") +parser.add_argument('-t', '--template', default="", type=str, help="Template name of the model, as specified by pytorch-transformers") +parser.add_argument('--cpu', action='store_true', help="Force models to be exported for CPU device") +parser.add_argument('--input-size', type=int, default=512, help="Length of the input sequence") +parser.add_argument('--vocab', action='store_true', help="Export the vocab.dat file along with the model.") +parser.add_argument('--train', action='store_true', help="Prepare model for training") +parser.add_argument('--num-labels', type=int, default=2, help="For sequence classification only: number of classes") + +args = parser.parse_args() + +if args.verbose: + logging.basicConfig(level=logging.INFO) + +model_classes = { + "bert": M.BertModel, + "bert_masked_lm": M.BertForMaskedLM, + "bert_classif": M.BertForSequenceClassification, + "roberta": M.RobertaModel, + "roberta_masked_lm": M.RobertaForMaskedLM, + "gpt2": M.GPT2Model, + "gpt2_lm": M.GPT2LMHeadModel, +} + +def get_model_type(mname): + for key in default_templates: + if mname == key or mname.startswith(key + "_"): + return key + return "" + + +default_templates = { + "bert": "bert-base-uncased", + "roberta":"roberta-base", + "gpt2": "gpt2", +} + +tokenizers = { + "bert": M.BertTokenizer, + "roberta": M.RobertaTokenizer, + "gpt2": M.GPT2Tokenizer, +} + +if args.all: + args.models = model_classes.keys() + +if args.print_models: + print("*** Available models ***") + for key in model_classes: + print(key) + sys.exit(0) +elif not args.models: + sys.stderr.write("Please specify at least one model to be exported\n") + sys.exit(-1) + +device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu' +logging.info("Device: %s", device) + +if args.input_size > 512 or args.input_size <= 0: + logging.error("This input size is not supported: %d", args.input_size) + sys.exit(-1) + +logging.info("Input size: %d", args.input_size) + +# Example inputs +input_ids = torch.ones((1, args.input_size), dtype=torch.long, device=device) +att_mask = torch.ones_like(input_ids) +token_type_ids = torch.zeros_like(input_ids) +position_ids = torch.arange(args.input_size, dtype=torch.long, device=device).unsqueeze(0) + +for mname in args.models: + if mname not in model_classes: + logging.warn("model %s is unknown and will not be exported", mname) + continue + + model_type = get_model_type(mname) + + # Find appropriate template + if args.template: + mtemplate = args.template + else: + mtemplate = default_templates[model_type] + + # Additionnal parameters + kvargs = dict() + if mname in ["bert_classif"]: + kvargs["num_labels"] = args.num_labels + if model_type in ["bert", "roberta"]: + kvargs["output_hidden_states"] = True + + # Create the model + mclass = model_classes[mname] + logging.info("Model class: %s", mclass.__name__) + logging.info("Use template '%s'", mtemplate) + model = mclass.from_pretrained(mtemplate, torchscript=True, **kvargs) + + if not args.pretrained: + logging.info("Create model from scratch with the same config as the pretrained one") + model = mclass(model.config) + + model.to(device) + if not args.train: + model.eval() + else: + model.train() + + # Trace the model with the correct inputs + if mname in ["bert", "bert_masked_lm", "bert_classif", "roberta", "roberta_masked_lm"]: + traced_model = torch.jit.trace(model, (input_ids, token_type_ids, att_mask)) + elif mname in ["distilbert", "distilbert_masked_lm"]: + traced_model = torch.jit.trace(model, (input_ids, att_mask)) + elif mname in ["gpt2", "gpt2_lm"]: + # change order of positional arguments + def real_forward(self, i, p): + return self.p_forward(input_ids=i, position_ids=p) + setattr(mclass, 'p_forward', mclass.forward) + setattr(mclass, 'forward', real_forward) + + traced_model = torch.jit.trace(model, (input_ids, position_ids)) + else: + raise ValueError("there is no method to trace this model: %s" % mname) + + filename = os.path.join(args.output_dir, mname + + ("-" + mtemplate if args.template in mclass.pretrained_model_archive_map else "") + + ("-pretrained" if args.pretrained else "") + ".pt") + logging.info("Saving to %s", filename) + traced_model.save(filename) + + # Export vocab.dat + if args.vocab: + tokenizer = tokenizers[model_type].from_pretrained(mtemplate) + filename = os.path.join(args.output_dir, "vocab.dat") + + with open(filename, 'w') as f: + for i in range(len(tokenizer)): + word = tokenizer.convert_ids_to_tokens([i])[0] + f.write(word + "\t" + str(i) + "\n") + + logging.info("Vocabulary saved to %s", filename) + +logging.info("Done") From ec3438f1b6a0ece8ee85964343c7d5874edb7c69 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 12 Sep 2019 14:46:27 +0200 Subject: [PATCH 09/14] Add classfication on hidden states to be able to use masked lm model for classif --- src/backends/torch/torchlib.cc | 10 +++++++++- src/backends/torch/torchlib.h | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index c7e3b4324..816dcde0e 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -90,6 +90,13 @@ namespace dd } } c10::IValue out_val = source.at(_classif_in); + if (_hidden_states) + { + // out_val is a tuple containing tensors of dimension n_batch * seq_len * n_features + // We want a tensor of size n_batch * n_features from the last hidden state + auto &elems = out_val.toTuple()->elements(); + out_val = elems.back().toTensor().slice(1, 0, 1).squeeze(1); + } if (_classif) { out_val = _classif->forward(to_tensor_safe(out_val)); @@ -122,7 +129,7 @@ namespace dd { if (!model._traced.empty()) _traced = torch::jit::load(model._traced, _device); - if (!model._weights.empty()) + if (!model._weights.empty() && _classif) torch::load(_classif, model._weights); } @@ -215,6 +222,7 @@ namespace dd _module._classif = nn::Linear(768, _nclasses); _module._classif->to(_device); + _module._hidden_states = true; _module._classif_in = 1; } } diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index 3c9ade21a..81bcceb17 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -54,6 +54,8 @@ namespace dd torch::Device _device; int _classif_in = 0; /**