Skip to content

Commit

Permalink
fix: seeded random crops at test time
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and Bycob committed Apr 28, 2022
1 parent 1ef2796 commit 92feae3
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 56 deletions.
21 changes: 15 additions & 6 deletions src/backends/torch/torchdataaug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace dd
{
int crop_x = 0;
int crop_y = 0;
applyCrop(src, _crop_params, crop_x, crop_y);
applyCrop(src, _crop_params, crop_x, crop_y, true, true);
}

void
Expand Down Expand Up @@ -170,7 +170,7 @@ namespace dd
}
int crop_x = 0;
int crop_y = 0;
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y);
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y, true, true);
if (cropped)
{
applyCropBBox(bboxes, classes, _crop_params,
Expand Down Expand Up @@ -222,7 +222,7 @@ namespace dd
{
int crop_x = 0;
int crop_y = 0;
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y);
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y, true, true);
if (cropped)
applyCrop(tgt, _crop_params, crop_x, crop_y, false);
}
Expand Down Expand Up @@ -346,7 +346,8 @@ namespace dd
}

bool TorchImgRandAugCV::applyCrop(cv::Mat &src, CropParams &cp, int &crop_x,
int &crop_y, const bool &sample)
int &crop_y, const bool &sample,
const bool &test)
{
if (cp._crop_size <= 0)
return false;
Expand All @@ -355,8 +356,16 @@ namespace dd
{
#pragma omp critical
{
crop_x = cp._uniform_int_crop_x(_rnd_gen);
crop_y = cp._uniform_int_crop_y(_rnd_gen);
if (test)
{
crop_x = cp._uniform_int_crop_x(_rnd_test_gen);
crop_y = cp._uniform_int_crop_y(_rnd_test_gen);
}
else
{
crop_x = cp._uniform_int_crop_x(_rnd_gen);
crop_y = cp._uniform_int_crop_y(_rnd_gen);
}
}
}
cv::Rect crop(crop_x, crop_y, cp._crop_size, cp._crop_size);
Expand Down
17 changes: 14 additions & 3 deletions src/backends/torch/torchdataaug.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#pragma GCC diagnostic pop
#include <random>

#define DATAAUG_TEST_SEED 23124534

namespace dd
{
class ImgAugParams
Expand Down Expand Up @@ -79,6 +81,7 @@ namespace dd
int _crop_size = -1;
std::uniform_int_distribution<int> _uniform_int_crop_x;
std::uniform_int_distribution<int> _uniform_int_crop_y;
int _test_crop_samples = 1; /**< number of sampled crops (at test time). */
};

class CutoutParams : public ImgAugParams
Expand Down Expand Up @@ -280,12 +283,18 @@ namespace dd
_cutout_params._img_width = _crop_params._crop_size;
_cutout_params._img_height = _crop_params._crop_size;
}
reset_rnd_test_gen();
}

~TorchImgRandAugCV()
{
}

void reset_rnd_test_gen()
{
_rnd_test_gen = std::default_random_engine(DATAAUG_TEST_SEED);
}

void augment(cv::Mat &src);
void augment_with_bbox(cv::Mat &src, std::vector<torch::Tensor> &targets);
void augment_with_segmap(cv::Mat &src, cv::Mat &tgt);
Expand All @@ -305,7 +314,7 @@ namespace dd
const float &img_width, const float &img_height,
const int &rot);
bool applyCrop(cv::Mat &src, CropParams &cp, int &crop_x, int &crop_y,
const bool &sample = true);
const bool &sample = true, const bool &test = false);
void applyCropBBox(std::vector<std::vector<float>> &bboxes,
std::vector<int> &classes, const CropParams &cp,
const float &img_width, const float &img_height,
Expand Down Expand Up @@ -347,8 +356,8 @@ namespace dd
void applyDistortHue(cv::Mat &src);
void applyDistortOrderChannel(cv::Mat &src);

private:
// augmentation options & parameter
public:
// augmentation options & parameters
bool _mirror = false;
bool _rotate = false;

Expand All @@ -360,6 +369,8 @@ namespace dd

// random generators
std::default_random_engine _rnd_gen;
std::default_random_engine
_rnd_test_gen; /**< test time, seeded generator. */
std::uniform_real_distribution<float>
_uniform_real_1; /**< random real uniform between 0 and 1. */
std::bernoulli_distribution _bernouilli;
Expand Down
115 changes: 76 additions & 39 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,9 @@ namespace dd
has_data = true;
}

// all data for one example
std::vector<torch::Tensor> d;
// all targets for one example
std::vector<torch::Tensor> t;

if (!_image)
Expand All @@ -522,6 +524,19 @@ namespace dd
std::stringstream targetstream(targets);
torch::load(d, datastream);
torch::load(t, targetstream);

for (unsigned int i = 0; i < d.size(); ++i)
{
while (i >= data.size())
data.emplace_back();
data.at(i).push_back(d[i]);
}
for (unsigned int i = 0; i < t.size(); ++i)
{
while (i >= target.size())
target.emplace_back();
target.at(i).push_back(t[i]);
}
}
else
{
Expand All @@ -533,52 +548,74 @@ namespace dd
inputc->_bw, inputc->width(),
inputc->height());

// data augmentation can apply here, with OpenCV
if (!_test)
{
if (_bbox)
_img_rand_aug_cv.augment_with_bbox(bgr, t);
else if (_segmentation)
_img_rand_aug_cv.augment_with_segmap(bgr, bw_target);
else
_img_rand_aug_cv.augment(bgr);
}
else
int samples = 1;

if (_test && _img_rand_aug_cv._crop_params._crop_size > 0)
samples = _img_rand_aug_cv._crop_params._test_crop_samples;

while (samples > 0)
{
// cropping requires test set 'augmentation'
if (_bbox)
cv::Mat bgr_sample = bgr.clone();
cv::Mat bw_target_sample;
std::vector<torch::Tensor> d_sample = d;
std::vector<torch::Tensor> t_sample = t;
if (_segmentation)
bw_target_sample = bw_target.clone();

// data augmentation can apply here, with OpenCV
if (!_test)
{
_img_rand_aug_cv.augment_test_with_bbox(bgr, t);
if (_bbox)
_img_rand_aug_cv.augment_with_bbox(bgr_sample,
t_sample);
else if (_segmentation)
_img_rand_aug_cv.augment_with_segmap(
bgr_sample, bw_target_sample);
else
_img_rand_aug_cv.augment(bgr_sample);
}
else if (_segmentation)
_img_rand_aug_cv.augment_test_with_segmap(bgr,
bw_target);
else
_img_rand_aug_cv.augment_test(bgr);
}
{
// cropping requires test set 'augmentation'
if (_bbox)
{
_img_rand_aug_cv.augment_test_with_bbox(bgr_sample,
t_sample);
}
else if (_segmentation)
_img_rand_aug_cv.augment_test_with_segmap(
bgr_sample, bw_target_sample);
else
_img_rand_aug_cv.augment_test(bgr_sample);
}

torch::Tensor imgt = image_to_tensor(bgr, bgr.rows, bgr.cols);
d.push_back(imgt);
torch::Tensor imgt = image_to_tensor(
bgr_sample, bgr_sample.rows, bgr_sample.cols);
d_sample.push_back(imgt);

if (_segmentation)
{
at::Tensor targett_seg = image_to_tensor(
bw_target, bw_target.rows, bw_target.cols, true);
t.push_back(targett_seg);
}
}
if (_segmentation)
{
at::Tensor targett_seg = image_to_tensor(
bw_target_sample, bw_target_sample.rows,
bw_target_sample.cols, true);
t_sample.push_back(targett_seg);
}

for (unsigned int i = 0; i < d.size(); ++i)
{
while (i >= data.size())
data.emplace_back();
data.at(i).push_back(d[i]);
}
for (unsigned int i = 0; i < t.size(); ++i)
{
while (i >= target.size())
target.emplace_back();
target.at(i).push_back(t[i]);
--samples;

for (unsigned int i = 0; i < d_sample.size(); ++i)
{
while (i >= data.size())
data.emplace_back();
data.at(i).push_back(d_sample[i]);
}
for (unsigned int i = 0; i < t_sample.size(); ++i)
{
while (i >= target.size())
target.emplace_back();
target.at(i).push_back(t_sample[i]);
}
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/backends/torch/torchdataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ namespace dd
bool _segmentation = false; /**< whether a segmentation dataset. */
std::vector<std::string> _dbFullNames;
std::vector<std::string> _datasets_names;
bool _test = false; /**< wheater a test set */
bool _test = false; /**< whether a test set */

protected:
bool _db = false;
Expand Down
6 changes: 6 additions & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,9 @@ namespace dd
int crop_size = ad_mllib.get("crop_size").get<int>();
crop_params
= CropParams(crop_size, inputc.width(), inputc.height());
if (ad_mllib.has("test_crop_samples"))
crop_params._test_crop_samples
= ad_mllib.get("test_crop_samples").get<int>();
this->_logger->info("crop_size : {}", crop_size);
}
CutoutParams cutout_params;
Expand Down Expand Up @@ -1855,6 +1858,9 @@ namespace dd
APIData ad_out = ad.getobj("parameters").getobj("output");
int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses;

// reset data aug test random generator
dataset._img_rand_aug_cv.reset_rnd_test_gen();

// confusion matrix is irrelevant to masked_lm training
if (_masked_lm && ad_out.has("measure"))
{
Expand Down
25 changes: 18 additions & 7 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -694,15 +694,23 @@ TEST(torchapi, service_train_images)
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":4,\"solver_type\":\"ADAM\",\"test_"
"interval\":200},\"net\":{\"batch_size\":4},"
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":"
"224,\"test_crop_samples\":10,"
"\"cutout\":0.5,\"geometry\":{"
"\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,"
"\"zoom_in\":true,\"zoom_out\":"
"true,"
"\"pad_mode\":\"constant\"},"
"\"noise\":{\"prob\":0.01},"
"\"distort\":{"
"\"prob\":0."
"01},\"dataloader_threads\":4}"
","
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true},"
"\"output\":{\"measure\":[\"f1\",\"acc\"]}},\"data\":[\""
"\"input\":{\"seed\":12345,\"db\":"
"true,\"shuffle\":true},"
"\"output\":{\"measure\":[\"f1\","
"\"acc\"]}},\"data\":[\""
+ resnet50_train_data + "\",\"" + resnet50_test_data + "\"]}";
joutstr = japi.jrender(japi.service_train(jtrainstr));
JDoc jd;
Expand Down Expand Up @@ -909,7 +917,9 @@ TEST(torchapi, service_train_image_segmentation_deeplabv3)
+ iterations_deeplabv3 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":1,\"solver_type\":\"ADAM\",\"test_"
"interval\":100},\"net\":{\"batch_size\":4},"
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":"
"224,"
"\"test_crop_samples\":10,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
Expand Down Expand Up @@ -1785,6 +1795,7 @@ TEST(torchapi, service_train_object_detection_yolox)
"type\":\"ADAM\",\"test_interval\":200},\"net\":{\"batch_size\":2,"
"\"test_batch_size\":2,\"reg_weight\":0.5},\"resume\":false,"
"\"mirror\":true,\"rotate\":true,\"crop_size\":512,"
"\"test_crop_samples\":10,"
"\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
Expand Down

0 comments on commit 92feae3

Please sign in to comment.