fix(torch): Fix update metrics and solver options when resuming
Bycob authored and mergify[bot] committed Apr 12, 2022
1 parent b88f22a commit 9b0019f
Showing 2 changed files with 131 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/backends/torch/torchsolver.cc
@@ -287,7 +287,6 @@ namespace dd
size_t end = sstate.rfind(".");
int it = std::stoi(sstate.substr(start, end - start));
_logger->info("Restarting optimization from iter {}", it);
- _logger->info("loading " + sstate);
try
{
torch::load(*_optimizer, sstate, device);
@@ -315,6 +314,8 @@ namespace dd
"Optimizer not created at resume time, this means that there are "
"no param.solver api data");
}

+ int it = 0;
// reload solver if asked for and set it value accordingly
if (ad_mllib.has("resume") && ad_mllib.get("resume").get<bool>())
{
@@ -328,7 +329,7 @@ namespace dd
else
try
{
- return load(mlmodel._sstate, main_device);
+ it = load(mlmodel._sstate, main_device);
}
catch (std::exception &e)
{
@@ -385,7 +386,7 @@ namespace dd
}

override_options();
- return 0;
+ return it;
}

void TorchSolver::override_options()
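
For context, here is a minimal standalone sketch of the pattern this hunk fixes, using a toy solver class rather than DeepDetect's actual TorchSolver API (ToySolver, resume_requested and the checkpoint name "solver-2.pt" are illustrative assumptions): the iteration parsed from the solver checkpoint filename is propagated out of the resume path instead of being dropped, so the caller can restart its training loop, learning-rate schedule and reported metrics at the correct step.

#include <iostream>
#include <string>

// Toy solver illustrating the fix: parse the iteration out of a checkpoint
// filename such as "solver-2.pt" and return it from the resume path instead
// of 0, so training restarts at the right step.
class ToySolver
{
public:
  int load(const std::string &sstate)
  {
    size_t start = sstate.rfind("-") + 1;
    size_t end = sstate.rfind(".");
    int it = std::stoi(sstate.substr(start, end - start));
    std::cout << "Restarting optimization from iter " << it << std::endl;
    // The real solver also reloads the optimizer state from sstate here.
    return it;
  }

  int resume(bool resume_requested, const std::string &sstate)
  {
    int it = 0;
    if (resume_requested && !sstate.empty())
      it = load(sstate);
    // Solver options (learning rate, clipping, ...) are re-applied here
    // before the resumed iteration is handed back to the caller.
    return it; // before the fix, this path returned 0 even after a load
  }
};

int main()
{
  ToySolver solver;
  int start_it = solver.resume(true, "solver-2.pt");
  for (int it = start_it; it < 3; ++it)
    std::cout << "training iteration " << it << std::endl;
}

In the actual diff above, the same idea is expressed by initializing int it = 0, assigning it = load(mlmodel._sstate, main_device) instead of returning immediately (so that override_options() still runs), and ending the function with return it; rather than return 0;.
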
127 changes: 127 additions & 0 deletions tests/ut-torchapi.cc
@@ -755,6 +755,133 @@ TEST(torchapi, service_train_images)
fileops::remove_dir(resnet50_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_train_resume)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
torch::manual_seed(torch_seed);
at::globalContext().setDeterministicCuDNN(true);

std::string iterations_resnet50 = "3";

// Create 2 images dataset
std::string resnet50_2_images_data = resnet50_train_repo + "2images/";
fileops::create_dir(resnet50_2_images_data, 0775);
fileops::create_dir(resnet50_2_images_data + "cats/", 0775);
fileops::create_dir(resnet50_2_images_data + "dogs/", 0775);
fileops::copy_file(resnet50_train_data + "cats/cat.10347.jpg",
resnet50_2_images_data + "cats/cat.10347.jpg");
fileops::copy_file(resnet50_train_data + "dogs/dog.10537.jpg",
resnet50_2_images_data + "dogs/dog.10537.jpg");

// Create service
JsonAPI japi;
std::string sname = "imgserv";
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ resnet50_train_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
"\"width\":256,\"height\":256,\"db\":true},\"mllib\":{\"nclasses\":"
"2,\"finetuning\":true,\"gpu\":true}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// Train
std::string jtrainstr
= "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
"\"mllib\":{\"solver\":{\"iterations\":"
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":1,\"solver_type\":\"ADAM\",\"test_"
"interval\":2,\"snapshot\":2,\"base_lr\":1e-5},\"net\":{\"batch_"
"size\":2},\"resume\":false},"
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":false},"
"\"output\":{\"measure\":[\"f1\",\"acc\"]}},\"data\":[\""
+ resnet50_2_images_data + "\",\"" + resnet50_2_images_data + "\"]}";
joutstr = japi.jrender(japi.service_train(jtrainstr));
JDoc jd;
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(201, jd["status"]["code"]);

// First predict
std::string jpredictstr = "{\"service\":\"imgserv\",\"parameters\":{"
"\"mllib\":{\"extract_layer\":\"last\"}},"
"\"data\":[\""
+ resnet50_test_image + "\"]}";
std::string out1 = japi.jrender(japi.service_predict(jpredictstr));
jd = JDoc();
jd.Parse<rapidjson::kParseNanAndInfFlag>(out1.c_str());
std::vector<double> out1_vals;
auto &jvals1 = jd["body"]["predictions"][0]["vals"];
for (size_t i = 0; i < jvals1.Size(); i++)
{
out1_vals.push_back(jvals1[i].GetDouble());
}

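// Drop the iteration-3 checkpoint and solver files so that the resumed run
// below starts from the snapshot taken at iteration 2 and retrains the last
// iteration.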
remove((resnet50_train_repo + "checkpoint-" + iterations_resnet50 + ".ptw")
.c_str());
remove((resnet50_train_repo + "checkpoint-" + iterations_resnet50 + ".pt")
.c_str());
remove(
(resnet50_train_repo + "solver-" + iterations_resnet50 + ".pt").c_str());

// Recreate service
japi.service_delete(sname, "");

joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// Resume
jtrainstr = "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
"\"mllib\":{\"solver\":{\"iterations\":"
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":1,\"solver_type\":\"ADAM\",\"test_"
"interval\":2,\"base_lr\":1e-5},\"net\":{\"batch_size\":2},"
"\"resume\":true},\"input\":{\"seed\":12345,\"db\":true,"
"\"shuffle\":false},"
"\"output\":{\"measure\":[\"f1\",\"acc\"]}},\"data\":[\""
+ resnet50_2_images_data + "\",\"" + resnet50_2_images_data
+ "\"]}";
joutstr = japi.jrender(japi.service_train(jtrainstr));
jd = JDoc();
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(201, jd["status"]["code"]);

// Predict
std::string out2 = japi.jrender(japi.service_predict(jpredictstr));
std::cout << "out1=" << out1 << std::endl;
std::cout << "out2=" << out2 << std::endl;
jd = JDoc();
jd.Parse<rapidjson::kParseNanAndInfFlag>(out2.c_str());
auto &jvals2 = jd["body"]["predictions"][0]["vals"];
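// After resuming, the predictions should match those of the uninterrupted
// run within a small tolerance.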
for (size_t i = 0; i < jvals2.Size(); i++)
{
ASSERT_TRUE(std::fabs(jvals2[i].GetDouble() - out1_vals.at(i)) < 0.001);
}

// remove files
std::unordered_set<std::string> lfiles;
fileops::list_directory(resnet50_train_repo, true, false, false, lfiles);
for (std::string ff : lfiles)
{
if (ff.find("checkpoint") != std::string::npos
|| ff.find("solver") != std::string::npos)
remove(ff.c_str());
}
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
+ iterations_resnet50 + ".ptw"));
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
+ iterations_resnet50 + ".pt"));

fileops::clear_directory(resnet50_train_repo + "train.lmdb");
fileops::clear_directory(resnet50_train_repo + "test_0.lmdb");
fileops::remove_dir(resnet50_train_repo + "train.lmdb");
fileops::remove_dir(resnet50_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_train_image_segmentation_deeplabv3)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
