fix: segmentation with torch backend + full cropping support
beniz authored and mergify[bot] committed Jan 12, 2022
1 parent 486ff30 commit e14c3f2
Showing 9 changed files with 100 additions and 21 deletions.
16 changes: 16 additions & 0 deletions src/backends/torch/torchdataaug.cc
@@ -55,6 +55,13 @@ namespace dd
applyDistort(src);
}

void TorchImgRandAugCV::augment_test(cv::Mat &src)
{
int crop_x = 0;
int crop_y = 0;
applyCrop(src, _crop_params, crop_x, crop_y);
}

void
TorchImgRandAugCV::augment_with_bbox(cv::Mat &src,
std::vector<torch::Tensor> &targets)
@@ -138,6 +145,15 @@ namespace dd
applyDistort(src);
}

void TorchImgRandAugCV::augment_test_with_segmap(cv::Mat &src, cv::Mat &tgt)
{
int crop_x = 0;
int crop_y = 0;
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y);
if (cropped)
applyCrop(tgt, _crop_params, crop_x, crop_y, false);
}

bool TorchImgRandAugCV::roll_weighted_dice(const float &prob)
{
// Draw random between 0 and 1
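The two new methods above are the test-time counterparts of augment()/augment_with_segmap(): they only ever apply the deterministic crop, never the random transforms. A minimal usage sketch (illustrative only, not part of the commit; assumes an already-configured TorchImgRandAugCV instance named aug and hypothetical image paths):

#include <opencv2/opencv.hpp>

cv::Mat img = cv::imread("sample.png");
cv::Mat mask = cv::imread("sample_mask.png", cv::IMREAD_GRAYSCALE); // class-index mask

// classification / plain images: crop the input only
aug.augment_test(img);

// segmentation: crop image and target with the same offsets
// (used instead of augment_test(), not in addition to it)
aug.augment_test_with_segmap(img, mask);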
19 changes: 16 additions & 3 deletions src/backends/torch/torchdataaug.h
@@ -131,19 +131,29 @@ namespace dd
GeometryParams(const float &prob, const bool &geometry_persp_horizontal,
const bool &geometry_persp_vertical,
const bool &geometry_zoom_out, const bool &geometry_zoom_in,
const int &geometry_pad_mode)
const std::string &geometry_pad_mode_str)
: _prob(prob), _geometry_persp_horizontal(geometry_persp_horizontal),
_geometry_persp_vertical(geometry_persp_vertical),
_geometry_zoom_out(geometry_zoom_out),
_geometry_zoom_in(geometry_zoom_in),
_geometry_pad_mode(geometry_pad_mode)
_geometry_zoom_in(geometry_zoom_in)
{
set_pad_mode(geometry_pad_mode_str);
}

~GeometryParams()
{
}

void set_pad_mode(const std::string &geometry_pad_mode_str)
{
if (geometry_pad_mode_str == "constant")
_geometry_pad_mode = 1;
else if (geometry_pad_mode_str == "mirrored")
_geometry_pad_mode = 2;
else if (geometry_pad_mode_str == "repeat_nearest")
_geometry_pad_mode = 3;
}

float _prob = 0.0;
bool _geometry_persp_horizontal
= true; /**< horizontal perspective change. */
@@ -275,6 +285,9 @@ namespace dd
void augment_with_bbox(cv::Mat &src, std::vector<torch::Tensor> &targets);
void augment_with_segmap(cv::Mat &src, cv::Mat &tgt);

void augment_test(cv::Mat &src);
void augment_test_with_segmap(cv::Mat &src, cv::Mat &tgt);

protected:
bool roll_weighted_dice(const float &prob);
bool applyMirror(cv::Mat &src, const bool &sample = true);
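GeometryParams now takes the padding mode as a string and converts it to the integer code used internally ("constant" -> 1, "mirrored" -> 2, "repeat_nearest" -> 3). A short sketch using the constructor signature shown above (values are arbitrary):

GeometryParams gp(0.1,          // probability
                  true, true,   // horizontal / vertical perspective
                  true, true,   // zoom out / zoom in
                  "mirrored");  // stored as _geometry_pad_mode == 2
gp.set_pad_mode("repeat_nearest"); // can also be changed later (== 3)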
27 changes: 21 additions & 6 deletions src/backends/torch/torchdataset.cc
@@ -160,11 +160,11 @@ namespace dd
{
// serialize image
std::ostringstream dstream;
image_to_stringstream(bgr, dstream, true);
image_to_stringstream(bgr, dstream, false);

// serialize target
std::ostringstream tstream;
image_to_stringstream(bw_target, tstream, false);
image_to_stringstream(bw_target, tstream, true);

write_image_to_db(dstream, tstream, bgr.rows, bgr.cols);
}
@@ -539,11 +539,22 @@ namespace dd
if (_bbox)
_img_rand_aug_cv.augment_with_bbox(bgr, t);
else if (_segmentation)
_img_rand_aug_cv.augment_with_segmap(bgr, bw_target);
else
_img_rand_aug_cv.augment(bgr);
}
else
{
// cropping requires test set 'augmentation'
if (_bbox)
{
_img_rand_aug_cv.augment_with_segmap(bgr, bw_target);
// no cropping yet with bboxes
}
if (_segmentation)
_img_rand_aug_cv.augment_test_with_segmap(bgr,
bw_target);
else
_img_rand_aug_cv.augment(bgr);
_img_rand_aug_cv.augment_test(bgr);
}

torch::Tensor imgt = image_to_tensor(bgr, bgr.rows, bgr.cols);
@@ -646,7 +657,10 @@ namespace dd
DDImg dimg;
inputc->copy_parameters_to(dimg);
if (target) // used for segmentation masks
dimg._bw = true;
{
dimg._bw = true;
dimg._interp = "nearest";
}

try
{
@@ -703,9 +717,10 @@ namespace dd
return res;
cv::Mat img_tgt;
res = read_image_file(fname_target, img_tgt, true);

if (res != 0)
return res;
add_image_batch(img, height, width, img_tgt);
add_image_batch(img, width, height, img_tgt);
return res;
}

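Segmentation targets are now read back with nearest-neighbor interpolation because mask pixels are class indices, not intensities; any interpolation that averages neighbors invents labels that do not exist in the dataset. An illustrative OpenCV-only sketch of the difference (not code from the commit):

#include <opencv2/opencv.hpp>

cv::Mat mask = cv::imread("sample_mask.png", cv::IMREAD_GRAYSCALE); // values in {0, 1, 2, ...}
cv::Mat good, bad;

// keeps the original class indices intact
cv::resize(mask, good, cv::Size(224, 224), 0, 0, cv::INTER_NEAREST);

// blends neighboring indices: a 0/2 border can produce a spurious class 1
cv::resize(mask, bad, cv::Size(224, 224), 0, 0, cv::INTER_LINEAR);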
9 changes: 9 additions & 0 deletions src/backends/torch/torchdataset.h
@@ -434,6 +434,15 @@ namespace dd
_batches_per_transaction = tsize;
}

/**
* \brief sets image data augmenter across test datasets
*/
void set_img_rand_aug_cv(const TorchImgRandAugCV &tiracv)
{
for (size_t i = 0; i < _datasets.size(); ++i)
_datasets.at(i)._img_rand_aug_cv = tiracv;
}

/**
* \brief commits final db transactions
*/
13 changes: 9 additions & 4 deletions src/backends/torch/torchlib.cc
@@ -626,7 +626,8 @@ namespace dd
{
bool has_data_augmentation
= ad_mllib.has("mirror") || ad_mllib.has("rotate")
|| ad_mllib.has("crop_size") || ad_mllib.has("cutout");
|| ad_mllib.has("crop_size") || ad_mllib.has("cutout")
|| ad_mllib.has("geometry");
if (has_data_augmentation)
{
bool has_mirror
@@ -670,8 +671,8 @@ namespace dd
geometry_params._geometry_zoom_in
= ad_geometry.get("zoom_in").get<bool>();
if (ad_geometry.has("pad_mode"))
geometry_params._geometry_pad_mode
= ad_geometry.get("pad_mode").get<int>();
geometry_params.set_pad_mode(
ad_geometry.get("pad_mode").get<std::string>());
}
auto *img_ic = reinterpret_cast<ImgTorchInputFileConn *>(&inputc);
NoiseParams noise_params;
@@ -693,6 +694,10 @@ namespace dd
inputc._dataset._img_rand_aug_cv = TorchImgRandAugCV(
has_mirror, has_rotate, crop_params, cutout_params,
geometry_params, noise_params, distort_params);
inputc._test_datasets.set_img_rand_aug_cv(TorchImgRandAugCV(
has_mirror, has_rotate, crop_params, cutout_params,
geometry_params, noise_params,
distort_params)); // only uses cropping if enabled
}
}
int dataloader_threads = 1;
@@ -1983,7 +1988,7 @@ namespace dd
output = torch::softmax(output, 1);
torch::Tensor target = batch.target.at(0).to(torch::kFloat64);
torch::Tensor segmap
= torch::flatten(torch::argmax(output.squeeze(), 1))
= torch::flatten(torch::argmax(output.squeeze(), 0))
.contiguous()
.to(torch::kFloat64)
.to(cpu); // squeeze removes the batch size
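The argmax dimension change matters because squeeze() removes the batch dimension first: for a single test image the output becomes [num_classes, H, W], so the per-pixel class must be taken over dimension 0 rather than 1. A small libtorch sketch of the corrected flow (shapes are illustrative):

torch::Tensor output = torch::rand({1, 21, 224, 224});  // [batch, classes, H, W]
output = torch::softmax(output, 1);                      // per-pixel class probabilities
torch::Tensor segmap
    = torch::flatten(torch::argmax(output.squeeze(), 0)) // squeeze -> [21, 224, 224], argmax over classes
          .contiguous()
          .to(torch::kFloat64);                           // flattened [224 * 224] class map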
4 changes: 3 additions & 1 deletion src/backends/torch/torchloss.cc
@@ -122,7 +122,9 @@ namespace dd
if (_loss.empty())
{
loss = torch::nn::functional::cross_entropy(
y_pred, y.squeeze(1).to(torch::kLong)); // TODO: options
y_pred, y.squeeze(1).to(torch::kLong),
torch::nn::functional::CrossEntropyFuncOptions().weight(
_class_weights));
}
else if (_loss == "dice" || _loss == "dice_multiclass"
|| _loss == "dice_weighted" || _loss == "dice_weighted_batch"
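The default segmentation loss now forwards per-class weights to cross_entropy via CrossEntropyFuncOptions, which helps with unbalanced masks. A standalone sketch of the call (weights and shapes are made up for illustration):

torch::Tensor y_pred = torch::rand({2, 3, 64, 64});              // [batch, classes, H, W] logits
torch::Tensor y = torch::randint(0, 3, {2, 1, 64, 64});          // [batch, 1, H, W] class indices
torch::Tensor class_weights = torch::tensor({0.2f, 1.0f, 1.5f}); // hypothetical per-class weights

torch::Tensor loss = torch::nn::functional::cross_entropy(
    y_pred, y.squeeze(1).to(torch::kLong),
    torch::nn::functional::CrossEntropyFuncOptions().weight(class_weights));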
9 changes: 9 additions & 0 deletions src/backends/torch/torchutils.h
@@ -44,6 +44,15 @@ namespace dd
namespace torch_utils
{

inline void cerr_tensor_shape(const std::string &tname,
const torch::Tensor t)
{
std::cerr << tname << " shape=";
for (auto d : t.sizes())
std::cerr << d << " ";
std::cerr << std::endl;
}

/**
* \brief empty cuda caching allocator, This should be called after every
* job that allocates a model. Pytorch keeps cuda memory allocated even
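The cerr_tensor_shape helper added above is a small debugging aid that dumps a tensor's dimensions to stderr, e.g.:

torch::Tensor batch = torch::zeros({4, 3, 224, 224});
torch_utils::cerr_tensor_shape("input batch", batch);
// prints: input batch shape=4 3 224 224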
18 changes: 12 additions & 6 deletions tests/ut-torchapi.cc
@@ -691,7 +691,8 @@ TEST(torchapi, service_train_images)
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0."
"01},\"dataloader_threads\":4}"
","
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true},"
@@ -778,7 +779,8 @@ TEST(torchapi, service_train_image_segmentation_deeplabv3)
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0."
"01}},"
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true,"
"\"segmentation\":true,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406]"
@@ -863,7 +865,8 @@ TEST(torchapi, service_train_image_segmentation_deeplabv3_dice)
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0."
"01}},"
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true,"
"\"segmentation\":true,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406]"
@@ -946,7 +949,8 @@ TEST(torchapi, service_train_image_segmentation_segformer)
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0."
"01}},"
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true,"
"\"segmentation\":true,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406]"
@@ -1562,7 +1566,8 @@ TEST(torchapi, service_train_object_detection_fasterrcnn)
"true,\"crop_size\":224,"
"\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0."
"01}},\"input\":{\"seed\":12347,"
"\"db\":true,\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},"
"\"data\":[\""
@@ -1633,7 +1638,8 @@ TEST(torchapi, service_train_object_detection_yolox)
"\"mirror\":true,\"rotate\":true,\"crop_size\":640,"
"\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0."
"01}},\"input\":{\"seed\":12347,\"db\":true,"
"\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\""
+ fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";
6 changes: 5 additions & 1 deletion tools/torch/trace_torchvision.py
@@ -30,6 +30,7 @@
import torch
import torchvision
import torchvision.models as M
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

parser = argparse.ArgumentParser(description="Trace image processing models from torchvision")
parser.add_argument('models', type=str, nargs='*', help="Models to trace.")
@@ -290,12 +291,15 @@ def get_detection_input(batch_size=1, img_width=224, img_height=224):

else:
kwargs = {}
if args.num_classes:
if args.num_classes and not segmentation:
logging.info("Using num_classes = %d" % args.num_classes)
kwargs["num_classes"] = args.num_classes

model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose, **kwargs)

if segmentation and 'deeplabv3' in mname:
model.classifier = DeepLabHead(2048, args.num_classes)

if args.to_dd_native:
# Make model NativeModuleWrapper compliant
model = Wrapper(model)
