Skip to content

Commit

Permalink
feat: training segmentation models with torch backend
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and mergify[bot] committed Nov 24, 2021
1 parent 9bda7f7 commit 1e3ff16
Show file tree
Hide file tree
Showing 9 changed files with 413 additions and 67 deletions.
162 changes: 131 additions & 31 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,29 +118,56 @@ namespace dd
}
}

void
TorchDataset::write_image_to_db(const cv::Mat &bgr,
const std::vector<torch::Tensor> &target)
void TorchDataset::image_to_stringstream(const cv::Mat &img,
std::ostringstream &dstream)
{
// serialize image
std::stringstream dstream;
std::vector<uint8_t> buffer;
std::vector<int> param = { cv::IMWRITE_JPEG_QUALITY, 100 };
cv::imencode(".jpg", bgr, buffer, param);
cv::imencode(".jpg", img, buffer, param);
for (uint8_t c : buffer)
dstream << c;
}

/**
 * \brief writes an encoded image and its serialized tensor target to db.
 * \param bgr input image
 * \param target list of target tensors saved with the torch archive format
 */
void
TorchDataset::write_image_to_db(const cv::Mat &bgr,
                                const std::vector<torch::Tensor> &target)
{
  // serialize the target tensors with the torch archive format
  std::ostringstream tgt_stream;
  torch::save(target, tgt_stream);

  // JPEG-encode the input image into a byte stream
  std::ostringstream img_stream;
  image_to_stringstream(bgr, img_stream);

  // key/value write; rows/cols are used for logging only
  write_image_to_db(img_stream, tgt_stream, bgr.rows, bgr.cols);
}

/**
 * \brief writes an encoded image and its segmentation mask to db.
 * \param bgr input image
 * \param bw_target single-channel segmentation mask holding per-pixel
 *        class labels
 */
void TorchDataset::write_image_to_db(const cv::Mat &bgr,
                                     const cv::Mat &bw_target)
{
  // serialize input image (JPEG, as for other image inputs)
  std::ostringstream dstream;
  image_to_stringstream(bgr, dstream);

  // serialize the segmentation mask losslessly: JPEG (even at quality
  // 100) is lossy and can corrupt per-pixel class labels, so the target
  // is PNG-encoded instead. cv::imdecode auto-detects the format on
  // read, so previously written JPEG targets still load.
  std::ostringstream tstream;
  std::vector<uint8_t> buffer;
  cv::imencode(".png", bw_target, buffer);
  if (!buffer.empty())
    tstream.write(reinterpret_cast<const char *>(buffer.data()),
                  static_cast<std::streamsize>(buffer.size()));

  write_image_to_db(dstream, tstream, bgr.rows, bgr.cols);
}

void TorchDataset::write_image_to_db(const std::ostringstream &dstream,
const std::ostringstream &tstream,
const int &height, const int &width)
{
// check on db
if (_dbData == nullptr)
{
_dbData = std::shared_ptr<db::DB>(db::GetDB(_backend));
_dbData->Open(_dbFullName, db::NEW);
_txn = std::shared_ptr<db::Transaction>(_dbData->NewTransaction());
_logger->info("Preparing db of {}x{} images", bgr.cols, bgr.rows);
_logger->info("Preparing db of {}x{} images", width, height);
}

// data & target keys
Expand Down Expand Up @@ -173,8 +200,20 @@ namespace dd
bgr = cv::Mat(img_data, true);
bgr = cv::imdecode(bgr,
bw ? CV_LOAD_IMAGE_GRAYSCALE : CV_LOAD_IMAGE_COLOR);
std::stringstream targetstream(targets);
torch::load(targett, targetstream);
cv::Mat bw_target; // for segmentation only.

if (_segmentation)
{
std::vector<uint8_t> img_target_data(targets.begin(), targets.end());
bw_target = cv::Mat(img_target_data, true);
bw_target = cv::imdecode(bw_target, CV_LOAD_IMAGE_GRAYSCALE);
}
else
{
std::stringstream targetstream(targets);
torch::load(targett, targetstream);
}

if (bgr.cols != width || bgr.rows != height)
{
float w_ratio = static_cast<float>(width) / bgr.cols;
Expand All @@ -189,6 +228,19 @@ namespace dd
targett[0][bb][2] *= w_ratio;
targett[0][bb][3] *= h_ratio;
}

if (_segmentation)
{
cv::resize(bw_target, bw_target, cv::Size(width, height), 0, 0,
cv::INTER_NEAREST);
}
}

if (_segmentation)
{
at::Tensor targett_seg
= image_to_tensor(bw_target, height, width, true);
targett.push_back(targett_seg);
}
}

Expand All @@ -210,6 +262,25 @@ namespace dd
}
}

/**
 * \brief adds an image / segmentation-mask pair to the dataset, either
 *        in memory as tensors or as encoded bytes in the db.
 * \param bgr input image
 * \param width preprocessed image width
 * \param height preprocessed image height
 * \param bw_target single-channel segmentation mask
 */
void TorchDataset::add_image_batch(const cv::Mat &bgr, const int &width,
                                   const int &height,
                                   const cv::Mat &bw_target)
{
  if (_db)
    {
      // db-backed dataset: store the encoded image/mask pair directly
      write_image_to_db(bgr, bw_target);
    }
  else
    {
      // in-memory dataset: convert both images to tensors
      at::Tensor input_t = image_to_tensor(bgr, height, width);
      at::Tensor target_t = image_to_tensor(bw_target, height, width, true);
      add_batch({ input_t }, { target_t });
    }
}

void TorchDataset::add_batch(const std::vector<at::Tensor> &data,
const std::vector<at::Tensor> &target)
{
Expand Down Expand Up @@ -462,6 +533,10 @@ namespace dd
{
if (_bbox)
_img_rand_aug_cv.augment_with_bbox(bgr, t);
else if (_segmentation)
{
// TODO: augment for segmentation
}
else
_img_rand_aug_cv.augment(bgr);
}
Expand Down Expand Up @@ -552,13 +627,16 @@ namespace dd
}

/*-- image tools --*/
int TorchDataset::read_image_file(const std::string &fname, cv::Mat &out)
int TorchDataset::read_image_file(const std::string &fname, cv::Mat &out,
const bool &target)
{
ImgTorchInputFileConn *inputc
= reinterpret_cast<ImgTorchInputFileConn *>(_inputc);

DDImg dimg;
inputc->copy_parameters_to(dimg);
if (target) // used for segmentation masks
dimg._bw = true;

try
{
Expand Down Expand Up @@ -605,6 +683,22 @@ namespace dd
return add_image_file(fname, { target_to_tensor(target) }, height, width);
}

int TorchDataset::add_image_image_file(const std::string &fname,
const std::string &fname_target,
const int &height, const int &width)
{
cv::Mat img;
int res = read_image_file(fname, img);
if (res != 0)
return res;
cv::Mat img_tgt;
res = read_image_file(fname_target, img_tgt, true);
if (res != 0)
return res;
add_image_batch(img, height, width, img_tgt);
return res;
}

int TorchDataset::add_image_bbox_file(const std::string &fname,
const std::string &bboxfname,
const int &height, const int &width)
Expand Down Expand Up @@ -675,7 +769,8 @@ namespace dd
}

at::Tensor TorchDataset::image_to_tensor(const cv::Mat &bgr,
const int &height, const int &width)
const int &height, const int &width,
const bool &target)
{
ImgTorchInputFileConn *inputc
= reinterpret_cast<ImgTorchInputFileConn *>(_inputc);
Expand All @@ -687,36 +782,41 @@ namespace dd
imgt = imgt.toType(at::kFloat).permute({ 2, 0, 1 });
size_t nchannels = imgt.size(0);

if (!inputc->_supports_bw && nchannels == 1)
if (!target)
{
this->_logger->warn("Model needs 3 input channel, input will be "
"duplicated to fit the model input format");
imgt = imgt.repeat({ 3, 1, 1 });
nchannels = 3;
}
if (!inputc->_supports_bw && nchannels == 1)
{
this->_logger->warn("Model needs 3 input channel, input will be "
"duplicated to fit the model input format");
imgt = imgt.repeat({ 3, 1, 1 });
nchannels = 3;
}

if (inputc->_scale != 1.0)
imgt = imgt.mul(inputc->_scale);
if (inputc->_scale != 1.0)
imgt = imgt.mul(inputc->_scale);

if (!inputc->_mean.empty() && inputc->_mean.size() != nchannels)
throw InputConnectorBadParamException(
"mean vector be of size the number of channels ("
+ std::to_string(nchannels) + ")");
if (!inputc->_mean.empty() && inputc->_mean.size() != nchannels)
throw InputConnectorBadParamException(
"mean vector be of size the number of channels ("
+ std::to_string(nchannels) + ")");

for (size_t m = 0; m < inputc->_mean.size(); m++)
imgt[m] = imgt[m].sub_(inputc->_mean.at(m));
for (size_t m = 0; m < inputc->_mean.size(); m++)
imgt[m] = imgt[m].sub_(inputc->_mean.at(m));

if (!inputc->_std.empty() && inputc->_std.size() != nchannels)
throw InputConnectorBadParamException(
"std vector be of size the number of channels ("
+ std::to_string(nchannels) + ")");
if (!inputc->_std.empty() && inputc->_std.size() != nchannels)
throw InputConnectorBadParamException(
"std vector be of size the number of channels ("
+ std::to_string(nchannels) + ")");

for (size_t s = 0; s < inputc->_std.size(); s++)
imgt[s] = imgt[s].div_(inputc->_std.at(s));
for (size_t s = 0; s < inputc->_std.size(); s++)
imgt[s] = imgt[s].div_(inputc->_std.at(s));
}

return imgt;
}

// TODO: segmentation target image to tensor

at::Tensor TorchDataset::target_to_tensor(const int &target)
{
at::Tensor targett{ torch::full(1, target, torch::kLong) };
Expand Down
55 changes: 48 additions & 7 deletions src/backends/torch/torchdataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ namespace dd
bool _classification = true; /**< whether a classification dataset. */

bool _image = false; /**< whether an image dataset. */
bool _bbox = false; /**< true if bbox detection dataset */
bool _bbox = false; /**< true if bbox detection dataset. */
bool _segmentation = false; /**< true if segmentation dataset. */
bool _test = false; /**< whether a test set */
TorchImgRandAugCV _img_rand_aug_cv; /**< image data augmentation policy. */

Expand All @@ -106,7 +107,8 @@ namespace dd
_indices(d._indices), _lfiles(d._lfiles), _batches(d._batches),
_dbFullName(d._dbFullName), _inputc(d._inputc),
_classification(d._classification), _image(d._image), _bbox(d._bbox),
_test(d._test), _img_rand_aug_cv(d._img_rand_aug_cv)
_segmentation(d._segmentation), _test(d._test),
_img_rand_aug_cv(d._img_rand_aug_cv)
{
}

Expand All @@ -129,6 +131,9 @@ namespace dd
const int &height,
const std::vector<at::Tensor> &targett);

void add_image_batch(const cv::Mat &bgr, const int &width,
const int &height, const cv::Mat &bw_target);

/**
* \brief reset dataset reading status : ie start new epoch
*/
Expand Down Expand Up @@ -273,7 +278,8 @@ namespace dd
}

/*-- image tools --*/
int read_image_file(const std::string &fname, cv::Mat &out);
int read_image_file(const std::string &fname, cv::Mat &out,
const bool &target = false);

int add_image_file(const std::string &fname,
const std::vector<at::Tensor> &target,
Expand All @@ -296,6 +302,15 @@ namespace dd
const std::vector<double> &target, const int &height,
const int &width);

/**
* \brief adds image from image filename, with an image as target
* \param width of preprocessed image
* \param height of preprocessed image
*/
int add_image_image_file(const std::string &fname,
const std::string &fname_target,
const int &height, const int &width);

/**
* \brief adds image to batch, with a bbox list file as target.
* \param width of preprocessed image
Expand All @@ -307,9 +322,13 @@ namespace dd

/**
* \brief turns an image into a torch::Tensor
* \param bgr input image
* \param height image height
* \param width image width
* \param target whether the image is a label/target
*/
at::Tensor image_to_tensor(const cv::Mat &bgr, const int &height,
const int &width);
const int &width, const bool &target = false);

/**
* \brief turns an int into a torch::Tensor
Expand All @@ -328,12 +347,31 @@ namespace dd
void write_tensors_to_db(const std::vector<at::Tensor> &data,
const std::vector<at::Tensor> &target);

/**
* \brief converts an image into a serialized string stream
*/
void image_to_stringstream(const cv::Mat &img,
std::ostringstream &dstream);

/**
* \brief writes encoded image to db with a tensor target
*/
void write_image_to_db(const cv::Mat &bgr,
const std::vector<torch::Tensor> &target);

/**
* \brief writes encoded images to db, one as input, the other as target
*/
void write_image_to_db(const cv::Mat &bgr, const cv::Mat &bw_target);

/**
* \brief write two stringstreams to db, as key and value.
* width and height are for logging purposes
*/
void write_image_to_db(const std::ostringstream &dstream,
const std::ostringstream &tstream,
const int &height, const int &width);

/**
* \brief reads an encoded image from db along with its tensor target
*/
Expand All @@ -360,9 +398,10 @@ namespace dd
*/
TorchMultipleDataset(const TorchMultipleDataset &d)
: _inputc(d._inputc), _image(d._image), _bbox(d._bbox),
_classification(d._classification), _dbFullNames(d._dbFullNames),
_datasets_names(d._datasets_names), _test(d._test), _db(d._db),
_backend(d._backend), _dbPrefix(d._dbPrefix), _logger(d._logger),
_classification(d._classification), _segmentation(d._segmentation),
_dbFullNames(d._dbFullNames), _datasets_names(d._datasets_names),
_test(d._test), _db(d._db), _backend(d._backend),
_dbPrefix(d._dbPrefix), _logger(d._logger),
_batches_per_transaction(d._batches_per_transaction),
_datasets(d._datasets)
{
Expand Down Expand Up @@ -519,6 +558,7 @@ namespace dd
_datasets[id]._inputc = _inputc;
_datasets[id]._image = _image;
_datasets[id]._bbox = _bbox;
_datasets[id]._segmentation = _segmentation;
_datasets[id]._test = _test;
_datasets[id]._classification = _classification;
_datasets[id].set_db_params(_db, _backend,
Expand All @@ -533,6 +573,7 @@ namespace dd
bool _image = false; /**< whether an image dataset. */
bool _bbox = false; /**< true if bbox detection dataset */
bool _classification = true; /**< whether a classification dataset. */
bool _segmentation = false; /**< whether a segmentation dataset. */
std::vector<std::string> _dbFullNames;
std::vector<std::string> _datasets_names;
bool _test = false; /**< whether a test set */
Expand Down
Loading

0 comments on commit 1e3ff16

Please sign in to comment.