Skip to content

Commit

Permalink
fix(torch): black&white image now working with crnn & dataaug
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Oct 28, 2023
1 parent c675876 commit 2b07002
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/backends/torch/native/templates/crnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ namespace dd

if (_timesteps > 0)
{
at::Tensor dummy_img = torch::zeros({ 1, 3, _img_height, _img_width });
at::Tensor dummy_img
= torch::zeros({ 1, _input_channels, _img_height, _img_width });
at::Tensor dummy = _backbone(dummy_img).reshape({ 1, -1, _timesteps });
output_channel = dummy.size(1);
// XXX should use logger
Expand Down
16 changes: 16 additions & 0 deletions src/backends/torch/torchdataaug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,15 @@ namespace dd
if (_noise_params._prob == 0.0)
return;

// sanity check
bool img_is_bw = src.channels() == 1;
if (img_is_bw
&& (_noise_params._hist_eq || _noise_params._decolorize
|| _noise_params._jpg || _noise_params._convert_to_hsv
|| _noise_params._convert_to_lab))
throw std::runtime_error(
"Image has one channel when 3 channel dataaug is enabled");

if (_noise_params._rgb)
{
cv::Mat bgr;
Expand Down Expand Up @@ -847,6 +856,13 @@ namespace dd
if (_distort_params._prob == 0.0)
return;

bool img_is_bw = src.channels() == 1;
if (img_is_bw
&& (_distort_params._saturation || _distort_params._hue
|| _distort_params._channel_order))
throw std::runtime_error(
"Image has one channel when 3 channel dataaug is enabled");

if (_distort_params._rgb)
{
cv::Mat bgr;
Expand Down
7 changes: 5 additions & 2 deletions src/backends/torch/torchdataaug.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ namespace dd
class NoiseParams
{
public:
NoiseParams()
NoiseParams(bool bw = false)
: _hist_eq(!bw), _decolorize(!bw), _jpg(!bw), _convert_to_hsv(!bw),
_convert_to_lab(!bw)
{
}

Expand Down Expand Up @@ -192,7 +194,8 @@ namespace dd
class DistortParams
{
public:
DistortParams()
DistortParams(bool bw = false)
: _saturation(!bw), _hue(!bw), _channel_order(!bw)
{
}

Expand Down
4 changes: 2 additions & 2 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -756,15 +756,15 @@ namespace dd
ad_geometry.get("pad_mode").get<std::string>());
}
auto *img_ic = reinterpret_cast<ImgTorchInputFileConn *>(&inputc);
NoiseParams noise_params;
NoiseParams noise_params(img_ic->_bw);
noise_params._rgb = img_ic->_rgb;
APIData ad_noise = ad_mllib.getobj("noise");
if (!ad_noise.empty())
{
noise_params._prob = ad_noise.get("prob").get<double>();
this->_logger->info("noise: {}", noise_params._prob);
}
DistortParams distort_params;
DistortParams distort_params(img_ic->_bw);
distort_params._rgb = img_ic->_rgb;
APIData ad_distort = ad_mllib.getobj("distort");
if (!ad_distort.empty())
Expand Down
74 changes: 74 additions & 0 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,80 @@ TEST(torchapi, service_train_images_ctc_native)
fileops::remove_dir(resnet18_ocr_train_repo + "test_0.lmdb");
}

// Smoke test for CTC training with the native "crnn" template on
// single-channel (black & white) input.
//
// The test creates an image service with "bw":true, trains it for 3
// iterations with geometry, noise and distort augmentation enabled, runs one
// prediction, then removes the generated checkpoint/solver files and the LMDB
// directories from the model repository. It only checks status codes and the
// presence of the final checkpoint, not the quality of the trained model.
TEST(torchapi, service_train_ctc_native_bw)
{
  // Just check that there are no errors when training in black&white
  // NOTE(review): the env var and the two calls below presumably make the
  // run reproducible on GPU -- confirm against the other tests in this file.
  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
  torch::manual_seed(torch_seed);
  at::globalContext().setDeterministicCuDNN(true);

  // Create service: image connector in bw mode, db + ctc enabled, crnn
  // template with 128 timesteps.
  JsonAPI japi;
  std::string sname = "imgserv";
  std::string jstr
      = "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
        "\"supervised\",\"model\":{\"repository\":\""
        + resnet18_ocr_train_repo
        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
          "\"width\":112,\"height\":32,\"bw\":true,\"db\":true,\"ctc\":true},"
          "\"mllib\":{\"template\":\"crnn\",\"gpu\":true,\"timesteps\":128}}}";
  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
  ASSERT_EQ(created_str, joutstr);

  // Train (few iterations). Noise and distort augmentation are both turned
  // on so the data augmentation path is exercised with one-channel images.
  std::string jtrainstr
      = "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
        "\"mllib\":{\"solver\":{\"iterations\":3,\"base_lr\":1e-4"
        ",\"iter_size\":4,\"solver_type\":\"ADAM\",\"test_"
        "interval\":200},\"net\":{\"batch_size\":32},"
        "\"resume\":false,\"mirror\":false,\"rotate\":false,"
        "\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
        "false,\"persp_vertical\":false,\"zoom_in\":true,\"zoom_out\":true,"
        "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
        "\"prob\":0.01},\"dataloader_threads\":4},"
        "\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true},"
        "\"output\":{\"measure\":[\"acc\"]}},\"data\":[\""
        + ocr_train_data + "\",\"" + ocr_test_data + "\"]}";
  joutstr = japi.jrender(japi.service_train(jtrainstr));
  JDoc jd;
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(201, jd["status"]["code"]);

  // predict on a single test image with CTC output decoding
  std::string jpredictstr = "{\"service\":\"imgserv\",\"parameters\":{"
                            "\"output\":{\"best\":1,\"ctc\":true}},"
                            "\"data\":[\""
                            + ocr_test_image + "\"]}";

  joutstr = japi.jrender(japi.service_predict(jpredictstr));
  jd = JDoc();
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(200, jd["status"]["code"]);

  // remove files: the checkpoint written at iteration 3 must exist before
  // cleanup and be gone afterwards.
  std::unordered_set<std::string> lfiles;
  fileops::list_directory(resnet18_ocr_train_repo, true, false, false, lfiles);
  ASSERT_TRUE(
      fileops::file_exists(resnet18_ocr_train_repo + "checkpoint-3.npt"));
  // Bind by const reference: the paths are only read, so copying each
  // std::string per iteration is unnecessary.
  for (const std::string &ff : lfiles)
    {
      if (ff.find("checkpoint") != std::string::npos
          || ff.find("solver") != std::string::npos)
        remove(ff.c_str());
    }
  ASSERT_TRUE(
      !fileops::file_exists(resnet18_ocr_train_repo + "checkpoint-3.npt"));

  // Drop the LMDB databases created by the "db":true input connector.
  fileops::clear_directory(resnet18_ocr_train_repo + "train.lmdb");
  fileops::clear_directory(resnet18_ocr_train_repo + "test_0.lmdb");
  fileops::remove_dir(resnet18_ocr_train_repo + "train.lmdb");
  fileops::remove_dir(resnet18_ocr_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_publish_trained_model)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
Expand Down

0 comments on commit 2b07002

Please sign in to comment.