Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob committed Sep 26, 2019
1 parent bdb5838 commit 7920fb7
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/backends/torch/torchinputconns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ void TorchDataset::reset()
}
}

// `request` holds the size of the batch
// Data selection and batch construction are done in this method
c10::optional<TorchBatch> TorchDataset::get_batch(BatchRequestType request)
{
size_t count = request[0];
Expand Down Expand Up @@ -134,6 +136,8 @@ TorchBatch TxtTorchInputFileConn::generate_masked_lm_batch(const TorchBatch &exa
std::uniform_real_distribution<double> uniform(0, 1);
std::uniform_int_distribution<int64_t> vocab_distrib(0, vocab_size() - 1);
Tensor input_ids = example.data.at(0).clone();
// lm_labels: n_batch * sequence_length
// equal to input_ids where tokens are masked, and -1 otherwise
Tensor lm_labels = torch::ones_like(input_ids, TensorOptions(kLong)) * -1;

// mask random tokens
Expand Down
1 change: 1 addition & 0 deletions src/backends/torch/torchinputconns.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace dd
std::vector<int64_t> _indices;

public:
/// Vector containing the whole dataset (the "cached data").
std::vector<TorchBatch> _batches;


Expand Down
5 changes: 5 additions & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ namespace dd
int batch_id = 0;
using namespace std::chrono;

// `it` is the iteration count (not the epoch count)
while (it < iterations)
{
if (!this->_tjob_running.load())
Expand Down Expand Up @@ -430,9 +431,12 @@ namespace dd
throw MLLibInternalException(std::string("Libtorch error:") + e.what());
}

// As CrossEntropyLoss is not available in Libtorch 1.1, we use nll_loss + log_softmax
Tensor loss;
if (_masked_lm)
{
// Convert [n_batch, sequence_length, vocab_size] to [n_batch * sequence_length, vocab_size]
// + ignore non-masked tokens (== -1 in target)
loss = torch::nll_loss(
torch::log_softmax(y_pred.view(IntList{-1, y_pred.size(2)}), 1),
y.view(IntList{-1}),
Expand Down Expand Up @@ -691,6 +695,7 @@ namespace dd
Tensor labels = batch.target[0].view(IntList{-1});
if (_masked_lm)
{
// Convert [n_batch, sequence_length, vocab_size] to [n_batch * sequence_length, vocab_size]
output = output.view(IntList{-1, output.size(2)});
}
output = torch::softmax(output, 1).to(cpu);
Expand Down
6 changes: 5 additions & 1 deletion src/backends/torch/torchlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

namespace dd
{
// TODO: Make TorchModule inherit from torch::nn::Module, and use the TORCH_MODULE macro?
class TorchModule {
public:
TorchModule();
Expand All @@ -44,8 +45,12 @@ namespace dd

std::vector<torch::Tensor> parameters();

/** Save traced module to checkpoint-[name].pt, and custom parts weights
* to checkpoint-[name].ptw */
// (Actually only _classif is saved in the .ptw)
void save_checkpoint(TorchModel &model, const std::string &name);

/** Load traced module from .pt and custom parts weights from .ptw */
void load(TorchModel &model);

void eval();
Expand All @@ -58,7 +63,6 @@ namespace dd

torch::Device _device;
int _classif_in = 0; /**<id of the input of the classification layer */
// XXX: This parameter is too specific
bool _hidden_states = false; /**< Take BERT hidden states as input. */
private:
bool _freeze_traced = false; /**< Freeze weights of the traced module */
Expand Down

0 comments on commit 7920fb7

Please sign in to comment.