Skip to content

Commit

Permalink
fix: no resize when training with images
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed May 15, 2023
1 parent 20d8ebe commit e84c616
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 71 deletions.
29 changes: 19 additions & 10 deletions src/backends/torch/torchdataaug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,17 +395,24 @@ namespace dd

if (sample)
{
int img_width = src.cols;
int img_height = src.rows;
std::uniform_int_distribution<int> uniform_int_crop_x(
0, img_width - cp._crop_size);
std::uniform_int_distribution<int> uniform_int_crop_y(
0, img_height - cp._crop_size);

#pragma omp critical
{
if (test)
{
crop_x = cp._uniform_int_crop_x(_rnd_test_gen);
crop_y = cp._uniform_int_crop_y(_rnd_test_gen);
crop_x = uniform_int_crop_x(_rnd_test_gen);
crop_y = uniform_int_crop_y(_rnd_test_gen);
}
else
{
crop_x = cp._uniform_int_crop_x(_rnd_gen);
crop_y = cp._uniform_int_crop_y(_rnd_gen);
crop_x = uniform_int_crop_x(_rnd_gen);
crop_y = uniform_int_crop_y(_rnd_gen);
}
}
}
Expand Down Expand Up @@ -464,20 +471,22 @@ namespace dd

#pragma omp critical
{
int img_width = src.cols;
int img_height = src.rows;
// get shape and area to erase
int w = 0, h = 0, rect_x = 0, rect_y = 0;
if (cp._w == 0 && cp._h == 0)
{
float s = cp._uniform_real_cutout_s(_rnd_gen) * cp._img_width
* cp._img_height; // area
float s = cp._uniform_real_cutout_s(_rnd_gen) * img_width
* img_height; // area
float r = cp._uniform_real_cutout_r(_rnd_gen); // aspect ratio

w = std::min(cp._img_width,
w = std::min(img_width,
static_cast<int>(std::floor(std::sqrt(s / r))));
h = std::min(cp._img_height,
h = std::min(img_height,
static_cast<int>(std::floor(std::sqrt(s * r))));
std::uniform_int_distribution<int> distx(0, cp._img_width - w);
std::uniform_int_distribution<int> disty(0, cp._img_height - h);
std::uniform_int_distribution<int> distx(0, img_width - w);
std::uniform_int_distribution<int> disty(0, img_height - h);
rect_x = distx(_rnd_gen);
rect_y = disty(_rnd_gen);
}
Expand Down
50 changes: 6 additions & 44 deletions src/backends/torch/torchdataaug.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,67 +33,34 @@

namespace dd
{
class ImgAugParams
class CropParams
{
public:
ImgAugParams() : _img_width(224), _img_height(224)
CropParams()
{
}

ImgAugParams(const int &img_width, const int &img_height)
: _img_width(img_width), _img_height(img_height)
CropParams(const int &crop_size) : _crop_size(crop_size)
{
}

~ImgAugParams()
{
}

int _img_width;
int _img_height;
};

class CropParams : public ImgAugParams
{
public:
CropParams() : ImgAugParams()
{
}

CropParams(const int &crop_size, const int &img_width,
const int &img_height)
: ImgAugParams(img_width, img_height), _crop_size(crop_size)
{
if (_crop_size > 0)
{
_uniform_int_crop_x
= std::uniform_int_distribution<int>(0, _img_width - _crop_size);
_uniform_int_crop_y = std::uniform_int_distribution<int>(
0, _img_height - _crop_size);
}
}

~CropParams()
{
}

// default params
int _crop_size = -1;
std::uniform_int_distribution<int> _uniform_int_crop_x;
std::uniform_int_distribution<int> _uniform_int_crop_y;
int _test_crop_samples = 1; /**< number of sampled crops (at test time). */
};

class CutoutParams : public ImgAugParams
class CutoutParams
{
public:
CutoutParams() : ImgAugParams()
CutoutParams()
{
}

CutoutParams(const float &prob, const int &img_width,
const int &img_height)
: ImgAugParams(img_width, img_height), _prob(prob)
CutoutParams(const float &prob) : _prob(prob)
{
_uniform_real_cutout_s
= std::uniform_real_distribution<float>(_cutout_sl, _cutout_sh);
Expand Down Expand Up @@ -287,11 +254,6 @@ namespace dd
_uniform_real_1(0.0, 1.0), _bernouilli(0.5),
_uniform_int_rotate(0, 3)
{
if (_crop_params._crop_size > 0)
{
_cutout_params._img_width = _crop_params._crop_size;
_cutout_params._img_height = _crop_params._crop_size;
}
reset_rnd_test_gen();
}

Expand Down
13 changes: 8 additions & 5 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ namespace dd
torch::load(targett, targetstream);
}

if (bgr.cols != width || bgr.rows != height)
if (width > 0 && height > 0 && (bgr.cols != width || bgr.rows != height))
{
cv::resize(bgr, bgr, cv::Size(width, height), 0, 0, cv::INTER_CUBIC);

Expand Down Expand Up @@ -860,10 +860,13 @@ namespace dd

std::ifstream infile(bboxfname);
std::string line;
double wfactor = static_cast<double>(inputc->_width)
/ static_cast<double>(orig_width);
double hfactor = static_cast<double>(inputc->_height)
/ static_cast<double>(orig_height);
double wfactor = inputc->_width > 0 ? static_cast<double>(inputc->_width)
/ static_cast<double>(orig_width)
: 1;
double hfactor = inputc->_height > 0
? static_cast<double>(inputc->_height)
/ static_cast<double>(orig_height)
: 1;

while (std::getline(infile, line))
{
Expand Down
18 changes: 10 additions & 8 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,7 @@ namespace dd
if (ad_mllib.has("crop_size"))
{
int crop_size = ad_mllib.get("crop_size").get<int>();
crop_params
= CropParams(crop_size, inputc.width(), inputc.height());
crop_params = CropParams(crop_size);
if (ad_mllib.has("test_crop_samples"))
crop_params._test_crop_samples
= ad_mllib.get("test_crop_samples").get<int>();
Expand All @@ -712,8 +711,7 @@ namespace dd
if (ad_mllib.has("cutout"))
{
float cutout = ad_mllib.get("cutout").get<double>();
cutout_params
= CutoutParams(cutout, inputc.width(), inputc.height());
cutout_params = CutoutParams(cutout);
this->_logger->info("cutout: {}", cutout);
}
GeometryParams geometry_params;
Expand Down Expand Up @@ -1640,6 +1638,10 @@ namespace dd
throw MLLibInternalException(
"Couldn't find original image size for " + uri);
}
int src_width
= inputc.width() > 0 ? inputc.width() : cols - 1;
int src_height
= inputc.height() > 0 ? inputc.height() : rows - 1;

APIData results_ad;
std::vector<double> probs;
Expand Down Expand Up @@ -1676,10 +1678,10 @@ namespace dd
this->_mlmodel.get_hcorresp(labels_acc[j]));

double bbox[] = {
bboxes_acc[j][0] / inputc.width() * (cols - 1),
bboxes_acc[j][1] / inputc.height() * (rows - 1),
bboxes_acc[j][2] / inputc.width() * (cols - 1),
bboxes_acc[j][3] / inputc.height() * (rows - 1),
bboxes_acc[j][0] / src_width * (cols - 1),
bboxes_acc[j][1] / src_height * (rows - 1),
bboxes_acc[j][2] / src_width * (cols - 1),
bboxes_acc[j][3] / src_height * (rows - 1),
};

// clamp bbox
Expand Down
8 changes: 4 additions & 4 deletions src/imginputfileconn.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ namespace dd
{
if (_scaled)
scale(src, dst);
else if (_width == 0 || _height == 0)
else if (_width < 0 || _height < 0)
{
if (_width == 0 && _height == 0)
if (_width < 0 && _height < 0)
{
// Do nothing and keep native resolution. May cause issues if
// batched images are different resolutions
Expand Down Expand Up @@ -199,9 +199,9 @@ namespace dd
{
if (_scaled)
scale_cuda(src, dst);
else if (_width == 0 || _height == 0)
else if (_width < 0 || _height < 0)
{
if (_width == 0 && _height == 0)
if (_width < 0 && _height < 0)
{
// Do nothing and keep native resolution. May cause issues if
// batched images are different resolutions
Expand Down
141 changes: 141 additions & 0 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,46 @@ TEST(torchapi, service_predict_object_detection)
ASSERT_EQ(preds_best.Size(), 3);
}

TEST(torchapi, service_predict_object_detection_any_size)
{
JsonAPI japi;
std::string sname = "detectserv";
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"fasterrcnn\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ detect_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"-1,\"width\":-1,\"rgb\":true,\"scale\":0.0039},\"mllib\":{"
"\"template\":\"fasterrcnn\"}}}";

std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);
std::string jpredictstr
= "{\"service\":\"detectserv\",\"parameters\":{"
"\"input\":{\"height\":-1,"
"\"width\":-1},\"output\":{\"bbox\":true, "
"\"best_bbox\":1,\"confidence_threshold\":0.8}},\"data\":[\""
+ detect_train_repo_fasterrcnn + "/imgs/000550-L.jpg\"]}";

joutstr = japi.jrender(japi.service_predict(jpredictstr));
JDoc jd;
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());

auto &preds = jd["body"]["predictions"][0]["classes"];
std::string cl1 = preds[0]["cat"].GetString();
ASSERT_TRUE(cl1 == "car");
ASSERT_TRUE(preds[0]["prob"].GetDouble() > 0.9);
auto &bbox = preds[0]["bbox"];
ASSERT_NEAR(bbox["xmin"].GetDouble(), 258.0, 5.0);
ASSERT_NEAR(bbox["ymin"].GetDouble(), 333.0, 5.0);
ASSERT_NEAR(bbox["xmax"].GetDouble(), 401.0, 5.0);
ASSERT_NEAR(bbox["ymax"].GetDouble(), 448.0, 5.0);
}

TEST(torchapi, service_predict_segmentation)
{
JsonAPI japi;
Expand Down Expand Up @@ -2748,6 +2788,107 @@ TEST(torchapi, service_train_object_detection_translation)
fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
}

TEST(torchapi, service_train_object_detection_yolox_any_size)
{
// Test with arbitrary image size: width = -1, height = -1
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
torch::manual_seed(torch_seed);
at::globalContext().setDeterministicCuDNN(true);

JsonAPI japi;
std::string sname = "detectserv";
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"yolox\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ detect_train_repo_yolox
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"-1,\"width\":-1,\"rgb\":true,\"bbox\":true,\"db\":true},"
"\"mllib\":{\"template\":\"yolox\",\"gpu\":true,"
"\"nclasses\":2}}}";

std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// Train
std::string jtrainstr
= "{\"service\":\"detectserv\",\"async\":false,\"parameters\":{"
"\"mllib\":{\"solver\":{\"iterations\":3"
+ std::string("")
//+ iterations_detection + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":2,\"solver_"
"type\":\"ADAM\",\"test_interval\":200},\"net\":{\"batch_size\":2,"
"\"test_batch_size\":1,\"reg_weight\":0.5},\"resume\":false,"
"\"mirror\":true,\"rotate\":true,\"crop_size\":512,"
"\"test_crop_samples\":10,"
"\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true,"
"\"shuffle\":true},\"output\":{\"measure\":[\"map-05\",\"map-50\","
"\"map-90\"]}},\"data\":[\""
+ fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";

joutstr = japi.jrender(japi.service_train(jtrainstr));
JDoc jd;
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(201, jd["status"]["code"]);

// ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map";
ASSERT_TRUE(jd["body"]["measure"]["map-05"].GetDouble() <= 1.0) << "map-05";
ASSERT_TRUE(jd["body"]["measure"]["map-50"].GetDouble() <= 1.0) << "map-50";
ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90";
// ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";

// check metrics
auto &meas = jd["body"]["measure"];
ASSERT_TRUE(meas.HasMember("iou_loss"));
ASSERT_TRUE(meas.HasMember("conf_loss"));
ASSERT_TRUE(meas.HasMember("cls_loss"));
ASSERT_TRUE(meas.HasMember("l1_loss"));
ASSERT_TRUE(meas.HasMember("train_loss"));
ASSERT_TRUE(
std::abs(meas["train_loss"].GetDouble()
- (meas["iou_loss"].GetDouble() * 0.5
+ meas["cls_loss"].GetDouble() + meas["l1_loss"].GetDouble()
+ meas["conf_loss"].GetDouble()))
< 0.0001);

// check that predict works fine
std::string jpredictstr = "{\"service\":\"detectserv\",\"parameters\":{"
"\"input\":{\"height\":-1,"
"\"width\":-1},\"output\":{\"bbox\":true, "
"\"confidence_threshold\":0.8}},\"data\":[\""
+ detect_train_repo_fasterrcnn
+ "/imgs/000550-L.jpg\"]}";
joutstr = japi.jrender(japi.service_predict(jpredictstr));
jd = JDoc();
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);

std::unordered_set<std::string> lfiles;
fileops::list_directory(detect_train_repo_yolox, true, false, false, lfiles);
for (std::string ff : lfiles)
{
if (ff.find("checkpoint") != std::string::npos
|| ff.find("solver") != std::string::npos)
remove(ff.c_str());
}
ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox + "checkpoint-"
+ iterations_detection + ".ptw"));
ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox + "checkpoint-"
+ iterations_detection + ".pt"));

fileops::clear_directory(detect_train_repo_yolox + "train.lmdb");
fileops::clear_directory(detect_train_repo_yolox + "test_0.lmdb");
fileops::remove_dir(detect_train_repo_yolox + "train.lmdb");
fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
}

TEST(torchapi, service_train_images_native)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
Expand Down

0 comments on commit e84c616

Please sign in to comment.