Skip to content

Commit

Permalink
Free memory when it's needed
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob committed Sep 25, 2019
1 parent 54c2830 commit ba54cba
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
31 changes: 27 additions & 4 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "torchlib.h"

#include <torch/script.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include "outputconnectorstrategy.h"

Expand Down Expand Up @@ -161,6 +162,12 @@ namespace dd
_classif->train();
}

void TorchModule::free()
{
_traced = nullptr;
_classif = nullptr;
}


// ======= TORCHLIB

Expand All @@ -187,7 +194,8 @@ namespace dd
template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy, TMLModel>::~TorchLib()
{

_module.free();
c10::cuda::CUDACachingAllocator::emptyCache();
}

/*- from mllib -*/
Expand Down Expand Up @@ -376,7 +384,9 @@ namespace dd
while (it < iterations)
{
if (!this->_tjob_running.load())
{
break;
}

double train_loss = 0;
double avg_it_time = 0;
Expand Down Expand Up @@ -427,6 +437,10 @@ namespace dd

if ((batch_id + 1) % iter_size == 0)
{
if (!this->_tjob_running.load())
{
break;
}
optimizer->step();
optimizer->zero_grad();
avg_it_time /= iter_size;
Expand All @@ -447,7 +461,14 @@ namespace dd

if (elapsed_it % test_interval == 0 && elapsed_it != iterations && !eval_dataset.empty())
{
// Free memory
loss = torch::empty(1);
y_pred = torch::empty(1);
y = torch::empty(1);
in_vals.clear();

APIData meas_out;
this->_logger->info("Start test");
test(ad, inputc, eval_dataset, test_batch_size, meas_out);
APIData meas_obj = meas_out.getobj("measure");
std::vector<std::string> meas_names = meas_obj.list_keys();
Expand Down Expand Up @@ -497,11 +518,13 @@ namespace dd

if (!this->_tjob_running.load())
{
this->_logger->info("Training stopped");
return 0;
this->_logger->info("Training job interrupted at iteration {}", it);
c10::cuda::CUDACachingAllocator::emptyCache();
return -1;
}

test(ad, inputc, inputc._test_dataset, test_batch_size, out);
c10::cuda::CUDACachingAllocator::emptyCache();

// Update model after training
this->_mlmodel.read_from_repository(this->_logger);
Expand Down Expand Up @@ -536,6 +559,7 @@ namespace dd
meas_out.erase("iteration");
meas_out.erase("train_loss");
out.add("measure", meas_out.getobj("measure"));
c10::cuda::CUDACachingAllocator::emptyCache();
return 0;
}

Expand Down Expand Up @@ -591,7 +615,6 @@ namespace dd
outputc.finalize(output_params, out, static_cast<MLModel*>(&this->_mlmodel));

out.add("status", 0);

return 0;
}

Expand Down
2 changes: 2 additions & 0 deletions src/backends/torch/torchlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ namespace dd

void eval();
void train();

void free();
public:
std::shared_ptr<torch::jit::script::Module> _traced;
torch::nn::Linear _classif = nullptr;
Expand Down

0 comments on commit ba54cba

Please sign in to comment.