diff --git a/src/backends/torch/torchinputconns.cc b/src/backends/torch/torchinputconns.cc index b5950f5ac..3bfaeaa63 100644 --- a/src/backends/torch/torchinputconns.cc +++ b/src/backends/torch/torchinputconns.cc @@ -2,7 +2,107 @@ namespace dd { +using namespace torch; + +// ===== TorchDataset + +void TorchDataset::add_batch(std::vector data, std::vector target) +{ + _batches.push_back(TorchBatch(data, target)); +} + +void TorchDataset::reset() +{ + _indices.clear(); + + for (int64_t i = 0; i < _batches.size(); ++i) { + _indices.push_back(i); + } + + if (_shuffle) + { + auto seed = _seed == -1 ? static_cast(time(NULL)) : _seed; + std::shuffle(_indices.begin(), _indices.end(), std::mt19937(seed)); + } +} + +// `request` holds the size of the batch +// Data selection and batch construction are done in this method +c10::optional TorchDataset::get_batch(BatchRequestType request) +{ + size_t count = request[0]; + count = count < _indices.size() ? count : _indices.size(); + + if (count == 0) { + return torch::nullopt; + } + + std::vector> data, target; + + while(count != 0) { + auto id = _indices.back(); + auto entry = _batches[id]; + + for (int i = 0; i < entry.data.size(); ++i) + { + while (i >= data.size()) + data.emplace_back(); + data[i].push_back(entry.data.at(i)); + } + for (int i = 0; i < entry.target.size(); ++i) + { + while (i >= target.size()) + target.emplace_back(); + target[i].push_back(entry.target.at(i)); + } + + _indices.pop_back(); + count--; + } + + std::vector data_tensors; + for (auto vec : data) + data_tensors.push_back(torch::stack(vec)); + + std::vector target_tensors; + for (auto vec : target) + target_tensors.push_back(torch::stack(vec)); + + return TorchBatch{ data_tensors, target_tensors }; +} + +TorchBatch TorchDataset::get_cached() { + reset(); + auto batch = get_batch({cache_size()}); + if (!batch) + throw InputConnectorInternalException("No data provided"); + return batch.value(); +} + +TorchDataset TorchDataset::split(double start, double stop) +{ + auto datasize = _batches.size(); + auto start_it = _batches.begin() + static_cast(datasize * start); + auto stop_it = _batches.end() - static_cast(datasize * (1 - stop)); + + TorchDataset new_dataset; + new_dataset._batches.insert(new_dataset._batches.end(), start_it, stop_it); + return new_dataset; +} + + +// ===== TxtTorchInputFileConn + +void TxtTorchInputFileConn::fillup_parameters(const APIData &ad_input) +{ + _width = this->_sequence; +} + void TxtTorchInputFileConn::transform(const APIData &ad) { + // if (_finetuning) + // XXX: Generating vocab from scratch is not currently + _generate_vocab = false; + try { TxtInputFileConn::transform(ad); @@ -12,68 +112,126 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { throw; } + if (!_ordered_words || _characters) + throw InputConnectorBadParamException("Need ordered_words = true with backend torch"); + if (ad.has("parameters") && ad.getobj("parameters").has("input")) { APIData ad_input = ad.getobj("parameters").getobj("input"); - if (ad_input.has("width")) - _width = ad_input.get("width").get(); + fillup_parameters(ad_input); } - if (!_ordered_words || _characters) - throw InputConnectorBadParamException("Need ordered_words = true with backend torch"); + _cls_pos = _vocab.at("[CLS]")._pos; + _sep_pos = _vocab.at("[SEP]")._pos; + _unk_pos = _vocab.at("[UNK]")._pos; + _mask_id = _vocab.at("[MASK]")._pos; - int cls_pos = _vocab.at("[CLS]")._pos; - int sep_pos = _vocab.at("[SEP]")._pos; - int unk_pos = _vocab.at("[UNK]")._pos; + fill_dataset(_dataset, _txt); + if (!_test_txt.empty()) + fill_dataset(_test_dataset, _test_txt); +} - std::vector vids; - std::vector vmask; +TorchBatch TxtTorchInputFileConn::generate_masked_lm_batch(const TorchBatch &example) +{ + std::uniform_real_distribution uniform(0, 1); + std::uniform_int_distribution vocab_distrib(0, vocab_size() - 1); + Tensor input_ids = example.data.at(0).clone(); + // lm_labels: n_batch * sequence_length + // equals to input_ids where tokens are masked, and -1 otherwise + Tensor lm_labels = torch::ones_like(input_ids, TensorOptions(kLong)) * -1; - for (auto *te : _txt) + // mask random tokens + auto input_acc = input_ids.accessor(); + auto att_mask_acc = example.data.at(2).accessor(); + auto labels_acc = lm_labels.accessor(); + for (int i = 0; i < input_ids.size(0); ++i) + { + int j = 1; // skip [CLS] token + while (j < input_ids.size(1) && att_mask_acc[i][j] != 0) + { + double rand_num = uniform(_rng); + if (rand_num < _lm_params._change_prob && input_acc[i][j] != _sep_pos) + { + labels_acc[i][j] = input_acc[i][j]; + + rand_num = uniform(_rng); + if (rand_num < _lm_params._mask_prob) + { + input_acc[i][j] = mask_id(); + } + else if (rand_num < _lm_params._mask_prob + _lm_params._rand_prob) + { + input_acc[i][j] = vocab_distrib(_rng); + } + } + ++j; + } + } + + TorchBatch output; + output.target.push_back(lm_labels); + output.data.push_back(input_ids); + for (int i = 1; i < example.data.size(); ++i) + { + output.data.push_back(example.data[i]); + } + return output; +} + +void TxtTorchInputFileConn::fill_dataset(TorchDataset &dataset, + const std::vector*> &entries) +{ + for (auto *te : entries) { TxtOrderedWordsEntry *tow = static_cast(te); tow->reset(); - std::vector ids; - ids.push_back(cls_pos); + ids.push_back(_cls_pos); while(tow->has_elt()) { + if (ids.size() >= _width - 1) + break; + std::string word; double val; tow->get_next_elt(word, val); std::unordered_map::iterator it; - + if ((it = _vocab.find(word)) != _vocab.end()) { ids.push_back(it->second._pos); } else { - ids.push_back(unk_pos); + ids.push_back(_unk_pos); } } - ids.push_back(sep_pos); + ids.push_back(_sep_pos); at::Tensor ids_tensor = toLongTensor(ids); at::Tensor mask_tensor = torch::ones_like(ids_tensor); - // at::Tensor token_type_ids_tensor = torch::zeros_like(ids_tensor); + at::Tensor token_type_ids_tensor = torch::zeros_like(ids_tensor); int64_t padding_size = _width - ids_tensor.sizes().back(); ids_tensor = torch::constant_pad_nd( ids_tensor, at::IntList{0, padding_size}, 0); mask_tensor = torch::constant_pad_nd( mask_tensor, at::IntList{0, padding_size}, 0); - // token_type_ids_tensor = torch::constant_pad_nd( - // token_type_ids_tensor, at::IntList{0, padding_size}, 0); + token_type_ids_tensor = torch::constant_pad_nd( + token_type_ids_tensor, at::IntList{0, padding_size}, 0); - vids.push_back(ids_tensor); - vmask.push_back(mask_tensor); - } + std::vector target_vec; + int target_val = static_cast(tow->_target); + if (target_val != -1) + { + Tensor target_tensor = torch::full(1, target_val, torch::kLong); + target_vec.push_back(target_tensor); + } - _in = torch::stack(vids, 0); - _attention_mask = torch::stack(vmask, 0); + dataset.add_batch({ids_tensor, token_type_ids_tensor, mask_tensor}, std::move(target_vec)); + } } -} \ No newline at end of file +} diff --git a/src/backends/torch/torchinputconns.h b/src/backends/torch/torchinputconns.h index 583978f30..cb07f93d3 100644 --- a/src/backends/torch/torchinputconns.h +++ b/src/backends/torch/torchinputconns.h @@ -31,22 +31,81 @@ namespace dd { + typedef torch::data::Example, std::vector> TorchBatch; + + class TorchDataset : public torch::data::BatchDataset + > + { + private: + bool _shuffle = false; + long _seed = -1; + std::vector _indices; + + public: + /// Vector containing the whole dataset (the "cached data"). + std::vector _batches; + + + TorchDataset() {} + + void add_batch(std::vector data, std::vector target = {}); + + void reset(); + + /// Size of data loaded in memory + size_t cache_size() const { return _batches.size(); } + + c10::optional size() const override { + return cache_size(); + } + + bool empty() const { return cache_size() == 0; } + + c10::optional get_batch(BatchRequestType request) override; + + /// Returns a batch containing all the cached data + TorchBatch get_cached(); + + /// Split a percentage of this dataset + TorchDataset split(double start, double stop); + }; + + + struct MaskedLMParams + { + double _change_prob = 0.15; /**< When masked LM learning, probability of changing a token (mask/randomize/keep). */ + double _mask_prob = 0.8; /**< When masked LM learning, probability of masking a token. */ + double _rand_prob = 0.1; /**< When masked LM learning, probability of randomizing a token. */ + }; + + class TorchInputInterface { public: TorchInputInterface() {} TorchInputInterface(const TorchInputInterface &i) - : _in(i._in), _attention_mask(i._attention_mask) {} + : _finetuning(i._finetuning), + _lm_params(i._lm_params), + _dataset(i._dataset), + _test_dataset(i._test_dataset) { } ~TorchInputInterface() {} torch::Tensor toLongTensor(std::vector &values) { int64_t val_size = values.size(); - return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong); + return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong).clone(); } - at::Tensor _in; - at::Tensor _attention_mask; + TorchBatch generate_masked_lm_batch(const TorchBatch &example) { return {}; } + + int64_t mask_id() const { return 0; } + int64_t vocab_size() const { return 0; } + + TorchDataset _dataset; + TorchDataset _test_dataset; + + MaskedLMParams _lm_params; + bool _finetuning; }; class ImgTorchInputFileConn : public ImgInputFileConn, public TorchInputInterface @@ -55,7 +114,7 @@ namespace dd ImgTorchInputFileConn() :ImgInputFileConn() {} ImgTorchInputFileConn(const ImgTorchInputFileConn &i) - :ImgInputFileConn(i),TorchInputInterface(i) {} + :ImgInputFileConn(i),TorchInputInterface(i), _std{i._std} {} ~ImgTorchInputFileConn() {} // for API info only @@ -74,7 +133,7 @@ namespace dd { ImgInputFileConn::init(ad); } - + void transform(const APIData &ad) { try @@ -97,7 +156,6 @@ namespace dd } } - std::vector tensors; std::vector sizes{ _height, _width, 3 }; at::TensorOptions options(at::ScalarType::Byte); @@ -106,10 +164,8 @@ namespace dd imgt = imgt.toType(at::kFloat).permute({2, 0, 1}); if (_std != 1.0) imgt = imgt.mul(1. / _std); - tensors.push_back(imgt); + _dataset.add_batch({imgt}); } - - _in = torch::stack(tensors, 0); } public: @@ -131,6 +187,14 @@ namespace dd _width(i._width), _height(i._height) {} ~TxtTorchInputFileConn() {} + void init(const APIData &ad) + { + TxtInputFileConn::init(ad); + fillup_parameters(ad); + } + + void fillup_parameters(const APIData &ad_input); + // for API info only int width() const { @@ -143,12 +207,26 @@ namespace dd return _height; } + int64_t mask_id() const { return _mask_id; } + + int64_t vocab_size() const { return _vocab.size(); } + void transform(const APIData &ad); + TorchBatch generate_masked_lm_batch(const TorchBatch &example); + + void fill_dataset(TorchDataset &dataset, const std::vector*> &entries); public: + /** width of the input tensor */ int _width = 512; int _height = 0; + std::mt19937 _rng; + + int64_t _mask_id = -1; /**< ID of mask token in the vocabulary. */ + int64_t _cls_pos = -1; + int64_t _sep_pos = -1; + int64_t _unk_pos = -1; }; } // namespace dd -#endif // TORCHINPUTCONNS_H \ No newline at end of file +#endif // TORCHINPUTCONNS_H diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index b76b62fdc..daced2ae4 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -22,6 +22,9 @@ #include "torchlib.h" #include +#if !defined(CPU_ONLY) +#include +#endif #include "outputconnectorstrategy.h" @@ -29,9 +32,156 @@ using namespace torch; namespace dd { + inline void empty_cuda_cache() { + #if !defined(CPU_ONLY) + c10::cuda::CUDACachingAllocator::emptyCache(); + #endif + } + + void add_parameters(std::shared_ptr module, std::vector ¶ms, bool requires_grad = true) { + for (const auto &slot : module->get_parameters()) { + Tensor tensor = slot.value().toTensor(); + if (tensor.requires_grad() && requires_grad) + params.push_back(tensor); + } + for (auto child : module->get_modules()) { + add_parameters(child, params); + } + } + + /// Convert IValue to Tensor and throw an exception if the IValue is not a Tensor. + Tensor to_tensor_safe(const IValue &value) { + if (!value.isTensor()) + throw MLLibInternalException("Expected Tensor, found " + value.tagKind()); + return value.toTensor(); + } + + /// Convert id Tensor to one_hot Tensor + void fill_one_hot(Tensor &one_hot, Tensor ids, int nclasses) + { + one_hot.zero_(); + for (int i = 0; i < ids.size(0); ++i) + { + one_hot[i][ids[i].item()] = 1; + } + } + + Tensor to_one_hot(Tensor ids, int nclasses) + { + Tensor one_hot = torch::zeros(IntList{ids.size(0), nclasses}); + for (int i = 0; i < ids.size(0); ++i) + { + one_hot[i][ids[i].item()] = 1; + } + return one_hot; + } + + // ======= TORCH MODULE + + + TorchModule::TorchModule() : _device{"cpu"} {} + + c10::IValue TorchModule::forward(std::vector source) + { + if (_traced) + { + auto output = _traced->forward(source); + if (output.isTensorList()) { + auto &elems = output.toTensorList()->elements(); + source = std::vector(elems.begin(), elems.end()); + } + else if (output.isTuple()) { + auto &elems = output.toTuple()->elements(); + source = std::vector(elems.begin(), elems.end()); + } + else { + source = { output }; + } + } + c10::IValue out_val = source.at(_classif_in); + if (_hidden_states) + { + // out_val is a tuple containing tensors of dimension n_batch * seq_len * n_features + // We want a tensor of size n_batch * n_features from the last hidden state + auto &elems = out_val.toTuple()->elements(); + out_val = elems.back().toTensor().slice(1, 0, 1).squeeze(1); + } + if (_classif) + { + out_val = _classif->forward(to_tensor_safe(out_val)); + } + return out_val; + } + + void TorchModule::freeze_traced(bool freeze) + { + if (freeze != _freeze_traced) + { + _freeze_traced = freeze; + std::vector params; + add_parameters(_traced, params, false); + for (auto ¶m : params) + { + param.set_requires_grad(!freeze); + } + } + } + + std::vector TorchModule::parameters() + { + std::vector params; + if (_traced) + add_parameters(_traced, params); + if (_classif) + { + auto classif_params = _classif->parameters(); + params.insert(params.end(), classif_params.begin(), classif_params.end()); + } + return params; + } + + void TorchModule::save_checkpoint(TorchModel &model, const std::string &name) + { + if (_traced) + _traced->save(model._repo + "/checkpoint-" + name + ".pt"); + if (_classif) + torch::save(_classif, model._repo + "/checkpoint-" + name + ".ptw"); + } + + void TorchModule::load(TorchModel &model) + { + if (!model._traced.empty()) + _traced = torch::jit::load(model._traced, _device); + if (!model._weights.empty() && _classif) + torch::load(_classif, model._weights); + } + + void TorchModule::eval() { + if (_traced) + _traced->eval(); + if (_classif) + _classif->eval(); + } + + void TorchModule::train() { + if (_traced) + _traced->train(); + if (_classif) + _classif->train(); + } + + void TorchModule::free() + { + _traced = nullptr; + _classif = nullptr; + } + + + // ======= TORCHLIB + template TorchLib::TorchLib(const TorchModel &tmodel) - : MLLib(tmodel) + : MLLib(tmodel) { this->_libname = "torch"; } @@ -41,54 +191,369 @@ namespace dd : MLLib(std::move(tl)) { this->_libname = "torch"; - _traced = std::move(tl._traced); + _module = std::move(tl._module); + _template = tl._template; _nclasses = tl._nclasses; _device = tl._device; - _attention = tl._attention; + _masked_lm = tl._masked_lm; + _finetuning = tl._finetuning; } template - TorchLib::~TorchLib() + TorchLib::~TorchLib() { - + _module.free(); + empty_cuda_cache(); } /*- from mllib -*/ template void TorchLib::init_mllib(const APIData &lib_ad) { + bool classification = false; bool gpu = false; int gpuid = -1; + bool freeze_traced = false; + int embedding_size = 768; + std::string self_supervised = ""; - if (lib_ad.has("gpu")) { + if (lib_ad.has("template")) + _template = lib_ad.get("template").get(); + if (lib_ad.has("gpu")) gpu = lib_ad.get("gpu").get() && torch::cuda::is_available(); - } if (lib_ad.has("gpuid")) gpuid = lib_ad.get("gpuid").get(); - if (lib_ad.has("nclasses")) { + if (lib_ad.has("nclasses")) + { + classification = true; _nclasses = lib_ad.get("nclasses").get(); } + if (lib_ad.has("self_supervised")) + self_supervised = lib_ad.get("self_supervised").get(); + if (lib_ad.has("embedding_size")) + embedding_size = lib_ad.get("embedding_size").get(); + if (lib_ad.has("finetuning")) + _finetuning = lib_ad.get("finetuning").get(); + if (lib_ad.has("freeze_traced")) + freeze_traced = lib_ad.get("freeze_traced").get(); _device = gpu ? torch::Device(DeviceType::CUDA, gpuid) : torch::Device(DeviceType::CPU); + _module._device = _device; + + // Create the model + if (this->_mlmodel._traced.empty()) + throw MLLibInternalException("Use of libtorch backend without traced net is not supported yet"); - if (typeid(TInputConnectorStrategy) == typeid(TxtTorchInputFileConn)) { - _attention = true; + if (_template == "bert") + { + if (classification) + { + _module._classif = nn::Linear(embedding_size, _nclasses); + _module._classif->to(_device); + + _module._hidden_states = true; + _module._classif_in = 1; + } + else if (!self_supervised.empty()) + { + if (self_supervised != "mask") + { + throw MLLibBadParamException("self_supervised"); + } + this->_logger->info("Masked Language model"); + _masked_lm = true; + } + else + { + throw MLLibBadParamException("BERT only supports self-supervised or classification"); + } + } + else if (!_template .empty()) + { + throw MLLibBadParamException("template"); } - _traced = torch::jit::load(this->_mlmodel._model_file, _device); - _traced->eval(); + this->_logger->info("Loading ml model from file {}.", this->_mlmodel._traced); + if (!this->_mlmodel._weights.empty()) + this->_logger->info("Loading weights from file {}.", this->_mlmodel._weights); + _module.load(this->_mlmodel); + _module.freeze_traced(freeze_traced); + + this->_mltype = "classification"; } template - void TorchLib::clear_mllib(const APIData &ad) + void TorchLib::clear_mllib(const APIData &ad) { - + std::vector extensions{".json", ".pt", ".ptw"}; + fileops::remove_directory_files(this->_mlmodel._repo, extensions); + this->_logger->info("Torchlib service cleared"); } template - int TorchLib::train(const APIData &ad, APIData &out) + int TorchLib::train(const APIData &ad, APIData &out) { - + this->_tjob_running.store(true); + + TInputConnectorStrategy inputc(this->_inputc); + inputc._train = true; + inputc._finetuning = _finetuning; + + try + { + inputc.transform(ad); + } + catch (...) + { + throw; + } + + APIData ad_mllib = ad.getobj("parameters").getobj("mllib"); + + // solver params + int64_t iterations = 1; + std::string solver_type = "SGD"; + double base_lr = 0.0001; + int64_t batch_size = 1; + int64_t iter_size = 1; + int64_t test_batch_size = 1; + int64_t test_interval = 1; + int64_t save_period = 0; + + // logging parameters + int64_t log_batch_period = 20; + + if (ad_mllib.has("solver")) + { + APIData ad_solver = ad_mllib.getobj("solver"); + if (ad_solver.has("iterations")) + iterations = ad_solver.get("iterations").get(); + if (ad_solver.has("solver_type")) + solver_type = ad_solver.get("solver_type").get(); + if (ad_solver.has("base_lr")) + base_lr = ad_solver.get("base_lr").get(); + if (ad_solver.has("test_interval")) + test_interval = ad_solver.get("test_interval").get(); + if (ad_solver.has("iter_size")) + iter_size = ad_solver.get("iter_size").get(); + if (ad_solver.has("snapshot")) + save_period = ad_solver.get("snapshot").get(); + } + + if (ad_mllib.has("net")) + { + APIData ad_net = ad_mllib.getobj("net"); + if (ad_net.has("batch_size")) + batch_size = ad_net.get("batch_size").get(); + if (ad_net.has("test_batch_size")) + test_batch_size = ad_net.get("test_batch_size").get(); + } + + if (iter_size <= 0) + iter_size = 1; + + // create dataset for evaluation during training + TorchDataset eval_dataset; + if (!inputc._test_dataset.empty()) + { + eval_dataset = inputc._test_dataset; //.split(0, 0.1); + } + + // create solver + std::unique_ptr optimizer; + + if (solver_type == "ADAM") + optimizer = std::unique_ptr( + new optim::Adam(_module.parameters(), optim::AdamOptions(base_lr))); + else if (solver_type == "RMSPROP") + optimizer = std::unique_ptr( + new optim::RMSprop(_module.parameters(), optim::RMSpropOptions(base_lr))); + else if (solver_type == "ADAGRAD") + optimizer = std::unique_ptr( + new optim::Adagrad(_module.parameters(), optim::AdagradOptions(base_lr))); + else + { + if (solver_type != "SGD") + this->_logger->warn("Solver type {} not found, using SGD", solver_type); + optimizer = std::unique_ptr( + new optim::SGD(_module.parameters(), optim::SGDOptions(base_lr))); + } + // reload solver + if (!this->_mlmodel._sstate.empty()) + { + this->_logger->info("Reload solver from {}", this->_mlmodel._sstate); + torch::load(*optimizer, this->_mlmodel._sstate); + } + optimizer->zero_grad(); + _module.train(); + + // create dataloader + auto dataloader = torch::data::make_data_loader( + std::move(inputc._dataset), + data::DataLoaderOptions(batch_size) + ); + + this->_logger->info("Training for {} iterations", iterations); + int it = 0; + int batch_id = 0; + using namespace std::chrono; + + // it is the iteration count (not epoch) + while (it < iterations) + { + if (!this->_tjob_running.load()) + { + break; + } + + double train_loss = 0; + double avg_it_time = 0; + + for (TorchBatch batch : *dataloader) + { + auto tstart = system_clock::now(); + if (_masked_lm) + { + batch = inputc.generate_masked_lm_batch(batch); + } + std::vector in_vals; + for (Tensor tensor : batch.data) + in_vals.push_back(tensor.to(_device)); + Tensor y = batch.target.at(0).to(_device); + + Tensor y_pred; + try + { + y_pred = to_tensor_safe(_module.forward(in_vals)); + } + catch (std::exception &e) + { + throw MLLibInternalException(std::string("Libtorch error:") + e.what()); + } + + // As CrossEntropy is not available (Libtorch 1.1) we use nllloss + log_softmax + Tensor loss; + if (_masked_lm) + { + // Convert [n_batch, sequence_length, vocab_size] to [n_batch * sequence_length, vocab_size] + // + ignore non-masked tokens (== -1 in target) + loss = torch::nll_loss( + torch::log_softmax(y_pred.view(IntList{-1, y_pred.size(2)}), 1), + y.view(IntList{-1}), + {}, Reduction::Mean, -1 + ); + } + else + { + loss = torch::nll_loss(torch::log_softmax(y_pred, 1), y.view(IntList{-1})); + } + if (iter_size > 1) + loss /= iter_size; + + double loss_val = loss.item(); + train_loss += loss_val; + loss.backward(); + auto tstop = system_clock::now(); + avg_it_time += duration_cast(tstop - tstart).count(); + + if ((batch_id + 1) % iter_size == 0) + { + if (!this->_tjob_running.load()) + { + break; + } + optimizer->step(); + optimizer->zero_grad(); + avg_it_time /= iter_size; + this->add_meas("learning_rate", base_lr); + this->add_meas("iteration", it); + this->add_meas("iter_time", avg_it_time); + this->add_meas("remain_time", avg_it_time * iter_size * (iterations - it) / 1000.0); + this->add_meas("train_loss", train_loss); + this->add_meas_per_iter("learning_rate", base_lr); + this->add_meas_per_iter("train_loss", train_loss); + int64_t elapsed_it = it + 1; + if (log_batch_period != 0 && elapsed_it % log_batch_period == 0) + { + this->_logger->info("Iteration {}/{}: loss is {}", elapsed_it, iterations, train_loss); + } + avg_it_time = 0; + train_loss = 0; + + if (elapsed_it % test_interval == 0 && elapsed_it != iterations && !eval_dataset.empty()) + { + // Free memory + loss = torch::empty(1); + y_pred = torch::empty(1); + y = torch::empty(1); + in_vals.clear(); + + APIData meas_out; + this->_logger->info("Start test"); + test(ad, inputc, eval_dataset, test_batch_size, meas_out); + APIData meas_obj = meas_out.getobj("measure"); + std::vector meas_names = meas_obj.list_keys(); + + for (auto name : meas_names) + { + if (name != "cmdiag" && name != "cmfull" && name != "labels") + { + double mval = meas_obj.get(name).get(); + this->_logger->info("{}={}", name, mval); + this->add_meas(name, mval); + this->add_meas_per_iter(name, mval); + } + else if (name == "cmdiag") + { + std::vector mdiag = meas_obj.get(name).get>(); + std::vector cnames; + std::string mdiag_str; + for (size_t i=0; i_mlmodel.get_hcorresp(i) + ":" + std::to_string(mdiag.at(i)) + " "; + this->add_meas_per_iter(name+'_'+this->_mlmodel.get_hcorresp(i), mdiag.at(i)); + cnames.push_back(this->_mlmodel.get_hcorresp(i)); + } + this->_logger->info("{}=[{}]", name, mdiag_str); + this->add_meas(name, mdiag, cnames); + } + } + } + + if ((save_period != 0 && elapsed_it % save_period == 0) || elapsed_it == iterations) + { + this->_logger->info("Saving checkpoint after {} iterations", elapsed_it); + _module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it)); + // Save optimizer + torch::save(*optimizer, this->_mlmodel._repo + "/solver-" + std::to_string(elapsed_it) + ".pt"); + } + ++it; + + if (it >= iterations) + break; + } + + ++batch_id; + } + } + + if (!this->_tjob_running.load()) + { + this->_logger->info("Training job interrupted at iteration {}", it); + empty_cuda_cache(); + return -1; + } + + test(ad, inputc, inputc._test_dataset, test_batch_size, out); + empty_cuda_cache(); + + // Update model after training + this->_mlmodel.read_from_repository(this->_logger); + this->_mlmodel.read_corresp_file(); + + inputc.response_params(out); + this->_logger->info("Training done."); + return 0; } template @@ -106,33 +571,34 @@ namespace dd } torch::Device cpu("cpu"); - std::vector in_vals; - in_vals.push_back(inputc._in.to(_device)); + _module.eval(); - if (_attention) { - // token_type_ids - in_vals.push_back(torch::zeros_like(inputc._in, at::kLong).to(_device)); - in_vals.push_back(inputc._attention_mask.to(_device)); + if (output_params.has("measure")) + { + APIData meas_out; + test(ad, inputc, inputc._dataset, 1, meas_out); + meas_out.erase("iteration"); + meas_out.erase("train_loss"); + out.add("measure", meas_out.getobj("measure")); + empty_cuda_cache(); + return 0; } + inputc._dataset.reset(); + std::vector in_vals; + for (Tensor tensor : inputc._dataset.get_cached().data) + in_vals.push_back(tensor.to(_device)); Tensor output; try { - c10::IValue out_val = _traced->forward(in_vals); - if (out_val.isTuple()) { - out_val = out_val.toTuple()->elements()[0]; - } - if (!out_val.isTensor()) { - throw MLLibInternalException("Model returned an invalid output. Please check your model."); - } - output = out_val.toTensor().to(at::kFloat); + output = to_tensor_safe(_module.forward(in_vals)); + output = torch::softmax(output, 1).to(cpu); } catch (std::exception &e) { throw MLLibInternalException(std::string("Libtorch error:") + e.what()); } - output = torch::softmax(output, 1).to(cpu); - + // Output std::vector results_ads; @@ -170,10 +636,102 @@ namespace dd outputc.finalize(output_params, out, static_cast(&this->_mlmodel)); out.add("status", 0); - return 0; } + template + int TorchLib::test(const APIData &ad, + TInputConnectorStrategy &inputc, + TorchDataset &dataset, + int batch_size, + APIData &out) + { + APIData ad_res; + APIData ad_out = ad.getobj("parameters").getobj("output"); + int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses; + + // confusion matrix is irrelevant to masked_lm training + if (_masked_lm && ad_out.has("measure")) + { + auto meas = ad_out.get("measure").get>(); + std::vector::iterator it; + if ((it = std::find(meas.begin(), meas.end(), "cmfull")) != meas.end()) + meas.erase(it); + if ((it = std::find(meas.begin(), meas.end(), "cmdiag")) != meas.end()) + meas.erase(it); + ad_out.add("measure", meas); + } + + auto dataloader = torch::data::make_data_loader( + dataset, + data::DataLoaderOptions(batch_size) + ); + torch::Device cpu("cpu"); + + _module.eval(); + int entry_id = 0; + for (TorchBatch batch : *dataloader) + { + if (_masked_lm) + { + batch = inputc.generate_masked_lm_batch(batch); + } + std::vector in_vals; + for (Tensor tensor : batch.data) + in_vals.push_back(tensor.to(_device)); + + Tensor output; + try + { + output = to_tensor_safe(_module.forward(in_vals)); + } + catch (std::exception &e) + { + throw MLLibInternalException(std::string("Libtorch error:") + e.what()); + } + + if (batch.target.empty()) + throw MLLibBadParamException("Missing label on data while testing"); + Tensor labels = batch.target[0].view(IntList{-1}); + if (_masked_lm) + { + // Convert [n_batch, sequence_length, vocab_size] to [n_batch * sequence_length, vocab_size] + output = output.view(IntList{-1, output.size(2)}); + } + output = torch::softmax(output, 1).to(cpu); + auto output_acc = output.accessor(); + auto labels_acc = labels.accessor(); + + for (int j = 0; j < labels.size(0); ++j) + { + if (_masked_lm && labels_acc[j] == -1) + continue; + + APIData bad; + std::vector predictions; + for (int c = 0; c < nclasses; c++) + { + predictions.push_back(output_acc[j][c]); + } + bad.add("target", static_cast(labels_acc[j])); + bad.add("pred", predictions); + ad_res.add(std::to_string(entry_id), bad); + ++entry_id; + } + // this->_logger->info("Testing: {}/{} entries processed", entry_id, test_size); + } + + ad_res.add("iteration",this->get_meas("iteration")); + ad_res.add("train_loss",this->get_meas("train_loss")); + std::vector clnames; + for (int i=0;i< nclasses;i++) + clnames.push_back(this->_mlmodel.get_hcorresp(i)); + ad_res.add("clnames", clnames); + ad_res.add("nclasses", nclasses); + ad_res.add("batch_size", entry_id); // here batch_size = tested entries count + SupervisedOutput::measure(ad_res, ad_out, out); + return 0; + } template class TorchLib; template class TorchLib; diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index ddb6809b5..ef7de00c7 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -22,6 +22,8 @@ #ifndef TORCHLIB_H #define TORCHLIB_H +#include + #include #include "apidata.h" @@ -32,6 +34,41 @@ namespace dd { + // TODO: Make TorchModule inherit torch::nn::Module ? And use the TORCH_MODULE macro + class TorchModule { + public: + TorchModule(); + + c10::IValue forward(std::vector source); + + void freeze_traced(bool freeze); + + std::vector parameters(); + + /** Save traced module to checkpoint-[name].pt, and custom parts weights + * to checkpoint-[name].ptw */ + // (Actually only _classif is saved in the .ptw) + void save_checkpoint(TorchModel &model, const std::string &name); + + /** Load traced module from .pt and custom parts weights from .ptw */ + void load(TorchModel &model); + + void eval(); + void train(); + + void free(); + public: + std::shared_ptr _traced; + torch::nn::Linear _classif = nullptr; + + torch::Device _device; + int _classif_in = 0; /**