From 2b070027944affedc753b9a88c7148a4f9fa71e3 Mon Sep 17 00:00:00 2001 From: Louis Jean Date: Fri, 27 Oct 2023 15:48:14 +0000 Subject: [PATCH] fix(torch): black&white image now working with crnn & dataaug --- src/backends/torch/native/templates/crnn.cc | 3 +- src/backends/torch/torchdataaug.cc | 16 +++++ src/backends/torch/torchdataaug.h | 7 +- src/backends/torch/torchlib.cc | 4 +- tests/ut-torchapi.cc | 74 +++++++++++++++++++++ 5 files changed, 99 insertions(+), 5 deletions(-) diff --git a/src/backends/torch/native/templates/crnn.cc b/src/backends/torch/native/templates/crnn.cc index 8ad78bad2..69fa25110 100644 --- a/src/backends/torch/native/templates/crnn.cc +++ b/src/backends/torch/native/templates/crnn.cc @@ -306,7 +306,8 @@ namespace dd if (_timesteps > 0) { - at::Tensor dummy_img = torch::zeros({ 1, 3, _img_height, _img_width }); + at::Tensor dummy_img + = torch::zeros({ 1, _input_channels, _img_height, _img_width }); at::Tensor dummy = _backbone(dummy_img).reshape({ 1, -1, _timesteps }); output_channel = dummy.size(1); // XXX should use logger diff --git a/src/backends/torch/torchdataaug.cc b/src/backends/torch/torchdataaug.cc index 6cc4640ca..1d26d0818 100644 --- a/src/backends/torch/torchdataaug.cc +++ b/src/backends/torch/torchdataaug.cc @@ -804,6 +804,15 @@ namespace dd if (_noise_params._prob == 0.0) return; + // sanity check + bool img_is_bw = src.channels() == 1; + if (img_is_bw + && (_noise_params._hist_eq || _noise_params._decolorize + || _noise_params._jpg || _noise_params._convert_to_hsv + || _noise_params._convert_to_lab)) + throw std::runtime_error( + "Image has one channel when 3 channel dataaug is enabled"); + if (_noise_params._rgb) { cv::Mat bgr; @@ -847,6 +856,13 @@ namespace dd if (_distort_params._prob == 0.0) return; + bool img_is_bw = src.channels() == 1; + if (img_is_bw + && (_distort_params._saturation || _distort_params._hue + || _distort_params._channel_order)) + throw std::runtime_error( + "Image has one channel when 3 channel dataaug is enabled"); + if (_distort_params._rgb) { cv::Mat bgr; diff --git a/src/backends/torch/torchdataaug.h b/src/backends/torch/torchdataaug.h index fceee521a..c719fbc60 100644 --- a/src/backends/torch/torchdataaug.h +++ b/src/backends/torch/torchdataaug.h @@ -156,7 +156,9 @@ namespace dd class NoiseParams { public: - NoiseParams() + NoiseParams(bool bw = false) + : _hist_eq(!bw), _decolorize(!bw), _jpg(!bw), _convert_to_hsv(!bw), + _convert_to_lab(!bw) { } @@ -192,7 +194,8 @@ namespace dd class DistortParams { public: - DistortParams() + DistortParams(bool bw = false) + : _saturation(!bw), _hue(!bw), _channel_order(!bw) { } diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index 3334cdfa5..a594fff20 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -756,7 +756,7 @@ namespace dd ad_geometry.get("pad_mode").get()); } auto *img_ic = reinterpret_cast(&inputc); - NoiseParams noise_params; + NoiseParams noise_params(img_ic->_bw); noise_params._rgb = img_ic->_rgb; APIData ad_noise = ad_mllib.getobj("noise"); if (!ad_noise.empty()) @@ -764,7 +764,7 @@ namespace dd noise_params._prob = ad_noise.get("prob").get(); this->_logger->info("noise: {}", noise_params._prob); } - DistortParams distort_params; + DistortParams distort_params(img_ic->_bw); distort_params._rgb = img_ic->_rgb; APIData ad_distort = ad_mllib.getobj("distort"); if (!ad_distort.empty()) diff --git a/tests/ut-torchapi.cc b/tests/ut-torchapi.cc index 6c61c8844..68fea6583 100644 --- a/tests/ut-torchapi.cc +++ b/tests/ut-torchapi.cc @@ -1610,6 +1610,80 @@ TEST(torchapi, service_train_images_ctc_native) fileops::remove_dir(resnet18_ocr_train_repo + "test_0.lmdb"); } +TEST(torchapi, service_train_ctc_native_bw) +{ + // Just check that there are no errors when training in black&white + setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true); + torch::manual_seed(torch_seed); + at::globalContext().setDeterministicCuDNN(true); + + // Create service + JsonAPI japi; + std::string sname = "imgserv"; + std::string jstr + = "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":" + "\"supervised\",\"model\":{\"repository\":\"" + + resnet18_ocr_train_repo + + "\"},\"parameters\":{\"input\":{\"connector\":\"image\"," + "\"width\":112,\"height\":32,\"bw\":true,\"db\":true,\"ctc\":true}," + "\"mllib\":{\"template\":\"crnn\",\"gpu\":true,\"timesteps\":128}}}"; + std::string joutstr = japi.jrender(japi.service_create(sname, jstr)); + ASSERT_EQ(created_str, joutstr); + + // Train (few iterations) + std::string jtrainstr + = "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{" + "\"mllib\":{\"solver\":{\"iterations\":3,\"base_lr\":1e-4" + ",\"iter_size\":4,\"solver_type\":\"ADAM\",\"test_" + "interval\":200},\"net\":{\"batch_size\":32}," + "\"resume\":false,\"mirror\":false,\"rotate\":false," + "\"geometry\":{\"prob\":0.1,\"persp_horizontal\":" + "false,\"persp_vertical\":false,\"zoom_in\":true,\"zoom_out\":true," + "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{" + "\"prob\":0.01},\"dataloader_threads\":4}," + "\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true}," + "\"output\":{\"measure\":[\"acc\"]}},\"data\":[\"" + + ocr_train_data + "\",\"" + ocr_test_data + "\"]}"; + joutstr = japi.jrender(japi.service_train(jtrainstr)); + JDoc jd; + std::cout << "joutstr=" << joutstr << std::endl; + jd.Parse(joutstr.c_str()); + ASSERT_TRUE(!jd.HasParseError()); + ASSERT_EQ(201, jd["status"]["code"]); + + // predict + std::string jpredictstr = "{\"service\":\"imgserv\",\"parameters\":{" + "\"output\":{\"best\":1,\"ctc\":true}}," + "\"data\":[\"" + + ocr_test_image + "\"]}"; + + joutstr = japi.jrender(japi.service_predict(jpredictstr)); + jd = JDoc(); + std::cout << "joutstr=" << joutstr << std::endl; + jd.Parse(joutstr.c_str()); + ASSERT_TRUE(!jd.HasParseError()); + ASSERT_EQ(200, jd["status"]["code"]); + + // remove files + std::unordered_set lfiles; + fileops::list_directory(resnet18_ocr_train_repo, true, false, false, lfiles); + ASSERT_TRUE( + fileops::file_exists(resnet18_ocr_train_repo + "checkpoint-3.npt")); + for (std::string ff : lfiles) + { + if (ff.find("checkpoint") != std::string::npos + || ff.find("solver") != std::string::npos) + remove(ff.c_str()); + } + ASSERT_TRUE( + !fileops::file_exists(resnet18_ocr_train_repo + "checkpoint-3.npt")); + + fileops::clear_directory(resnet18_ocr_train_repo + "train.lmdb"); + fileops::clear_directory(resnet18_ocr_train_repo + "test_0.lmdb"); + fileops::remove_dir(resnet18_ocr_train_repo + "train.lmdb"); + fileops::remove_dir(resnet18_ocr_train_repo + "test_0.lmdb"); +} + TEST(torchapi, service_publish_trained_model) { setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);