Skip to content

Commit

Permalink
fix(torch): predictions handled correctly when data count > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Jul 20, 2021
1 parent 4f17897 commit 5a95c29
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
20 changes: 8 additions & 12 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,25 +217,20 @@ namespace dd
write_tensors_to_db(data, target);
}

void TorchDataset::reset(bool shuffle, db::Mode dbmode)
void TorchDataset::reset(db::Mode dbmode)
{
std::lock_guard<std::mutex> guard(_mutex);
_shuffle = shuffle;
size_t data_size = 0;

if (!_db)
{
if (!_lfiles.empty()) // list of files
{
_indices = std::vector<int64_t>(_lfiles.size());
std::iota(std::begin(_indices), std::end(_indices), 0);
data_size = _lfiles.size();
}
else if (!_batches.empty())
{
_indices = std::vector<int64_t>(_batches.size());
std::iota(std::begin(_indices), std::end(_indices), 0);
}
else
{
_indices.clear();
data_size = _batches.size();
}
}
else // below db case
Expand All @@ -249,10 +244,11 @@ namespace dd
if (!_dbCursor)
_dbCursor = _dbData->NewCursor();

_indices = std::vector<int64_t>(_dbData->Count());
std::iota(std::begin(_indices), std::end(_indices), 0);
data_size = _dbData->Count();
}

_indices.resize(data_size);
std::iota(std::rbegin(_indices), std::rend(_indices), 0);
if (_shuffle)
{
std::shuffle(_indices.begin(), _indices.end(), _rng);
Expand Down
8 changes: 7 additions & 1 deletion src/backends/torch/torchdataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ namespace dd
/**
* \brief reset dataset reading status : ie start new epoch
*/
void reset(bool shuffle = true, db::Mode dbmode = db::READ);
void reset(db::Mode dbmode = db::READ);

void reset(bool shuffle, db::Mode dbMode = db::READ)
{
_shuffle = shuffle;
reset(dbMode);
}

/**
* \brief setter for _shuffle
Expand Down
28 changes: 28 additions & 0 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,34 @@ TEST(torchapi, service_predict)
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
ASSERT_EQ(jd["body"]["predictions"][0]["classes"].Size(), 8);

// batch size == 2
jpredictstr
= "{\"service\":\"imgserv\",\"parameters\":{\"input\":{\"height\":224,"
"\"width\":224},\"mllib\":{\"net\":{\"test_batch_size\":2}},"
"\"output\":{\"best\":1}},\"data\":[\""
+ incept_repo + "cat.jpg\",\"" + incept_repo + "dog.jpg\"]}";
joutstr = japi.jrender(japi.service_predict(jpredictstr));
jd = JDoc();
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
ASSERT_EQ(jd["body"]["predictions"].Size(), 2);
cl1 = jd["body"]["predictions"][0]["classes"][0]["cat"].GetString();
std::string cl2
= jd["body"]["predictions"][1]["classes"][0]["cat"].GetString();
std::string cl_cat = jd["body"]["predictions"][0]["uri"].GetString()
== incept_repo + "cat.jpg"
? cl1
: cl2;
std::string cl_dog = jd["body"]["predictions"][1]["uri"].GetString()
== incept_repo + "dog.jpg"
? cl2
: cl1;
ASSERT_EQ(cl_cat, "n02123045 tabby, tabby cat");
ASSERT_EQ(cl_dog, "n02096051 Airedale, Airedale terrier");
}

TEST(torchapi, service_predict_native_bw)
Expand Down

0 comments on commit 5a95c29

Please sign in to comment.