From 4e74172d7760bd8016402d6f730d2a21ef0b5469 Mon Sep 17 00:00:00 2001
From: Peter Jin <pjin@nvidia.com>
Date: Thu, 22 Jun 2017 16:35:34 -0700
Subject: [PATCH 1/2] Preliminary implementation of data augmentation for
 variable-sized batch elements.

Random isotropic resizing draws a random shape. Requires a discrete
uniform distribution.

Implement random and center cropping pre-transform. Make sure that
`DataReader::sample` gives a pre-transformed Datum.

For pre-transform image resizing, use cubic interpolation when upsampling,
and use a pyramid of linear interpolations followed by cubic interpolation
when downsampling.

Bugfix: update the shape during the downsampling pyramid.

When resizing, avoid extra copies and use proper rounding.

Enable pre-transforms for non-encoded Datum.

Move the pre-transform data augmentation code to DataLayer.

Remove whitespace for cleaner diff.

Moving most of variable-sized image data augmentation to
DataTransformer.

Refactoring the shape inference.

Add missing guards and remove unnecessary includes.

Fix a documentation typo.

Clear the label when doing variable-sized transform.

Add consts.

Rename some functions. Shorten the new prototxt names for readability.
Add const to functions. Add assert messages.

Shorten some lines to pass lint.

Add tests for variable sized image transforms.

Omit debug checks.

Use cv::Mat references where possible. Try to better explain that the
variable sized transforms are a sequence or pipeline of transforms.
Variable sized transform tests reuse more code. Other short fixes.
---
 include/caffe/data_transformer.hpp       |  62 +++++
 include/caffe/util/io.hpp                |   1 +
 include/caffe/util/math_functions.hpp    |   3 +
 src/caffe/data_transformer.cpp           | 328 ++++++++++++++++++++---
 src/caffe/layers/data_layer.cpp          |  34 ++-
 src/caffe/proto/caffe.proto              |  30 +++
 src/caffe/test/test_data_transformer.cpp |  55 ++++
 src/caffe/util/io.cpp                    |  72 +++--
 src/caffe/util/math_functions.cpp        |  16 ++
 9 files changed, 540 insertions(+), 61 deletions(-)

diff --git a/include/caffe/data_transformer.hpp b/include/caffe/data_transformer.hpp
index 2608af9fb58..db19a79d3c8 100644
--- a/include/caffe/data_transformer.hpp
+++ b/include/caffe/data_transformer.hpp
@@ -1,6 +1,12 @@
 #ifndef CAFFE_DATA_TRANSFORMER_HPP
 #define CAFFE_DATA_TRANSFORMER_HPP
 
+#ifdef USE_OPENCV
+
+#include <opencv2/core/core.hpp>
+
+#endif  // USE_OPENCV
+
 #include <string>
 #include <vector>
 
@@ -49,6 +55,44 @@ class DataTransformer {
   void CopyPtrEntry(shared_ptr<Datum> datum, Dtype* transformed_ptr, size_t& out_sizeof_element,
       bool output_labels, Dtype* label);
 
+  /**
+   * @brief Whether there are any "variable_sized" transformations defined
+   * in the data layer's transform_param block.
+   */
+  bool var_sized_transforms_enabled() const;
+
+  /**
+   * @brief Calculate the final shape from applying the "variable_sized"
+   * transformations defined in the data layer's transform_param block
+   * on the provided image, without actually performing any transformations.
+   *
+   * @param orig_shape
+   *    The shape of the data to be transformed.
+   */
+  vector<int> var_sized_transforms_shape(const vector<int>& orig_shape) const;
+
+  /**
+   * @brief Applies "variable_sized" transformations defined in the data layer's
+   * transform_param block to the data.
+   *
+   * @param old_datum
+   *    The source Datum containing data of arbitrary shape.
+   * @param new_datum
+   *    The destination Datum that will store transformed data of a fixed
+   *    shape. Suitable for other transformations.
+   */
+  shared_ptr<Datum> VariableSizedTransforms(const Datum& old_datum);
+
+  bool        var_sized_image_random_resize_enabled() const;
+  vector<int> var_sized_image_random_resize_shape(const vector<int>& prev_shape) const;
+  cv::Mat&    var_sized_image_random_resize(cv::Mat& img);
+  bool        var_sized_image_random_crop_enabled() const;
+  vector<int> var_sized_image_random_crop_shape(const vector<int>& prev_shape) const;
+  cv::Mat&    var_sized_image_random_crop(const cv::Mat& img);
+  bool        var_sized_image_center_crop_enabled() const;
+  vector<int> var_sized_image_center_crop_shape(const vector<int>& prev_shape) const;
+  cv::Mat&    var_sized_image_center_crop(const cv::Mat& img);
+
   /**
    * @brief Applies the transformation defined in the data layer's
    * transform_param block to the data.
@@ -137,6 +181,20 @@ class DataTransformer {
       const std::array<unsigned int, 3>& rand);
 #endif  // USE_OPENCV
 
+  vector<int> InferDatumShape(const Datum& datum);
+#ifdef USE_OPENCV
+  vector<int> InferCVMatShape(const cv::Mat& img);
+#endif  // USE_OPENCV
+
+  /**
+   * @brief Infers the shape of transformed_blob will have when
+   *    the transformation is applied to the data.
+   *
+   * @param bottom_shape
+   *    The shape of the data to be transformed.
+   */
+  vector<int> InferBlobShape(const vector<int>& bottom_shape, bool use_gpu = false);
+
   /**
    * @brief Infers the shape of transformed_blob will have when
    *    the transformation is applied to the data.
@@ -180,6 +238,10 @@ class DataTransformer {
 #ifndef CPU_ONLY
   GPUMemory::Workspace mean_values_gpu_;
 #endif
+  shared_ptr<Datum> var_sized_transform_datum_;
+  cv::Mat tmp_rand_resize_img_;
+  cv::Mat tmp_rand_crop_img_;
+  cv::Mat tmp_center_crop_img_;
 };
 
 }  // namespace caffe
diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp
index 1a599883ca3..4c9834bcfed 100644
--- a/include/caffe/util/io.hpp
+++ b/include/caffe/util/io.hpp
@@ -144,6 +144,7 @@ cv::Mat ReadImageToCVMat(const string& filename);
 cv::Mat DecodeDatumToCVMatNative(const Datum& datum);
 cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
 
+cv::Mat DatumToCVMat(const Datum& datum);
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
 #endif  // USE_OPENCV
 
diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp
index 24235ee9f7b..d77b664e2c6 100644
--- a/include/caffe/util/math_functions.hpp
+++ b/include/caffe/util/math_functions.hpp
@@ -102,6 +102,9 @@ Dtype caffe_nextafter(const Dtype b);
 template <typename Dtype>
 void caffe_rng_uniform(int n, float a, float b, Dtype* r);
 
+template <>
+void caffe_rng_uniform<int>(int n, float a, float b, int* r);
+
 template <typename Dtype>
 void caffe_rng_gaussian(int n, float mu, float sigma, Dtype* r);
 
diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp
index 4ff07b3133b..6ed75f02fcd 100644
--- a/src/caffe/data_transformer.cpp
+++ b/src/caffe/data_transformer.cpp
@@ -1,6 +1,7 @@
 #ifdef USE_OPENCV
 
 #include <opencv2/core/core.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
 
 #endif  // USE_OPENCV
 
@@ -36,6 +37,7 @@ DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param, Ph
       mean_values_.push_back(param_.mean_value(c));
     }
   }
+  var_sized_transform_datum_ = make_shared<Datum>();
 }
 
 #ifdef USE_OPENCV
@@ -136,6 +138,248 @@ void DataTransformer<Dtype>::Fill3Randoms(unsigned int *rand) const {
   }
 }
 
+template<typename Dtype>
+bool DataTransformer<Dtype>::var_sized_transforms_enabled() const {
+  return var_sized_image_random_resize_enabled() ||
+      var_sized_image_random_crop_enabled() ||
+      var_sized_image_center_crop_enabled();
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::var_sized_transforms_shape(
+    const vector<int>& orig_shape) const {
+  CHECK_EQ(orig_shape.size(), 4);
+  // All of the transforms (random resize, random crop, center crop)
+  // can be enabled, and they operate sequentially, one after the other.
+  vector<int> shape(orig_shape);
+  if (var_sized_image_random_resize_enabled()) {
+    shape = var_sized_image_random_resize_shape(shape);
+  }
+  if (var_sized_image_random_crop_enabled()) {
+    shape = var_sized_image_random_crop_shape(shape);
+  }
+  if (var_sized_image_center_crop_enabled()) {
+    shape = var_sized_image_center_crop_shape(shape);
+  }
+  CHECK_NE(shape[2], 0)
+      << "variable sized transform has invalid output height; did you forget to crop?";
+  CHECK_NE(shape[3], 0)
+      << "variable sized transform has invalid output width; did you forget to crop?";
+  return shape;
+}
+
+template<typename Dtype>
+shared_ptr<Datum> DataTransformer<Dtype>::VariableSizedTransforms(const Datum& old_datum) {
+  cv::Mat orig_img;
+  if (old_datum.encoded()) {
+    CHECK(!(param_.force_color() && param_.force_gray()))
+        << "cannot set both force_color and force_gray";
+    if (param_.force_color() || param_.force_gray()) {
+      // If force_color then decode in color otherwise decode in gray.
+      orig_img = DecodeDatumToCVMat(old_datum, param_.force_color());
+    } else {
+      orig_img = DecodeDatumToCVMatNative(old_datum);
+    }
+  } else {
+    orig_img = DatumToCVMat(old_datum);
+  }
+  cv::Mat& img = orig_img;
+  if (var_sized_image_random_resize_enabled()) {
+    img = var_sized_image_random_resize(img);
+  }
+  if (var_sized_image_random_crop_enabled()) {
+    img = var_sized_image_random_crop(img);
+  }
+  if (var_sized_image_center_crop_enabled()) {
+    img = var_sized_image_center_crop(img);
+  }
+  {
+    Datum* new_datum = var_sized_transform_datum_.get();
+    CVMatToDatum(img, new_datum);
+    if (old_datum.has_label()) {
+      new_datum->set_label(old_datum.label());
+    } else {
+      new_datum->clear_label();
+    }
+    new_datum->set_record_id(old_datum.record_id());
+  }
+  return var_sized_transform_datum_;
+}
+
+template<typename Dtype>
+bool DataTransformer<Dtype>::var_sized_image_random_resize_enabled() const {
+  const int resize_lower = param_.var_sz_img_rand_resize_lower();
+  const int resize_upper = param_.var_sz_img_rand_resize_upper();
+  if (resize_lower == 0 && resize_upper == 0) {
+    return false;
+  } else if (resize_lower != 0 && resize_upper != 0) {
+    return true;
+  }
+  LOG(FATAL)
+      << "random resize 'lower' and 'upper' parameters must either "
+      "both be zero or both be nonzero";
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::var_sized_image_random_resize_shape(
+    const vector<int>& prev_shape) const {
+  CHECK(var_sized_image_random_resize_enabled())
+      << "var sized transform must be enabled";
+  CHECK_EQ(prev_shape.size(), 4)
+      << "input shape should always have 4 axes (NCHW)";
+  vector<int> shape(4);
+  shape[0] = 1;
+  shape[1] = prev_shape[1];
+  // The output of a random resize is itself a variable sized image.
+  // By itself a random resize cannot produce an image that is valid input for
+  // downstream transformations, and must instead be terminated by a
+  // variable-sized crop (either random or center).
+  shape[2] = 0;
+  shape[3] = 0;
+  return shape;
+}
+
+template<typename Dtype>
+cv::Mat& DataTransformer<Dtype>::var_sized_image_random_resize(cv::Mat& img) {
+  const int resize_lower = param_.var_sz_img_rand_resize_lower();
+  const int resize_upper = param_.var_sz_img_rand_resize_upper();
+  CHECK_GT(resize_lower, 0)
+      << "random resize lower bound parameter must be positive";
+  CHECK_GT(resize_upper, 0)
+      << "random resize lower bound parameter must be positive";
+  int resize_size = -1;
+  caffe_rng_uniform(
+      1,
+      static_cast<float>(resize_lower), static_cast<float>(resize_upper),
+      &resize_size);
+  CHECK_NE(resize_size, -1)
+      << "uniform random sampling inexplicably failed";
+  const int img_height = img.rows;
+  const int img_width = img.cols;
+  const double scale = (img_width >= img_height) ?
+      ((static_cast<double>(resize_size)) / (static_cast<double>(img_height))) :
+      ((static_cast<double>(resize_size)) / (static_cast<double>(img_width)));
+  const int resize_height = static_cast<int>(std::round(scale * static_cast<double>(img_height)));
+  const int resize_width = static_cast<int>(std::round(scale * static_cast<double>(img_width)));
+  if (resize_height < img_height || resize_width < img_width) {
+    // Downsample with pixel area relation interpolation.
+    CHECK_LE(resize_height, img_height)
+        << "cannot downsample width without downsampling height";
+    CHECK_LE(resize_width, img_width)
+        << "cannot downsample height without downsampling width";
+    cv::resize(
+        img, tmp_rand_resize_img_,
+        cv::Size(resize_width, resize_height),
+        0.0, 0.0,
+        cv::INTER_AREA);
+    return tmp_rand_resize_img_;
+  } else if (resize_height > img_height || resize_width > img_width) {
+    // Upsample with cubic interpolation.
+    CHECK_GE(resize_height, img_height)
+        << "cannot upsample width without upsampling height";
+    CHECK_GE(resize_width, img_width)
+        << "cannot upsample height without upsampling width";
+    cv::resize(
+        img, tmp_rand_resize_img_,
+        cv::Size(resize_width, resize_height),
+        0.0, 0.0,
+        cv::INTER_CUBIC);
+    return tmp_rand_resize_img_;
+  } else if (resize_height == img_height && resize_width == img_width) {
+    return img;
+  }
+  LOG(FATAL)
+      << "unreachable random resize shape: ("
+      << img_width << ", " << img_height << ") => ("
+      << resize_width << ", " << resize_height << ")";
+}
+
+template<typename Dtype>
+bool DataTransformer<Dtype>::var_sized_image_random_crop_enabled() const {
+  const int crop_size = param_.var_sz_img_rand_crop();
+  return crop_size != 0;
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::var_sized_image_random_crop_shape(
+    const vector<int>& prev_shape) const {
+  CHECK(var_sized_image_random_crop_enabled())
+      << "var sized transform must be enabled";
+  const int crop_size = param_.var_sz_img_rand_crop();
+  CHECK_EQ(prev_shape.size(), 4)
+      << "input shape should always have 4 axes (NCHW)";
+  vector<int> shape(4);
+  shape[0] = 1;
+  shape[1] = prev_shape[1];
+  shape[2] = crop_size;
+  shape[3] = crop_size;
+  return shape;
+}
+
+template<typename Dtype>
+cv::Mat& DataTransformer<Dtype>::var_sized_image_random_crop(const cv::Mat& img) {
+  const int crop_size = param_.var_sz_img_rand_crop();
+  CHECK_GT(crop_size, 0)
+      << "random crop size parameter must be positive";
+  const int img_height = img.rows;
+  const int img_width = img.cols;
+  CHECK_GE(img_height, crop_size)
+      << "crop size parameter must be at least as large as the image height";
+  CHECK_GE(img_width, crop_size)
+      << "crop size parameter must be at least as large as the image width";
+  int crop_offset_h = -1;
+  int crop_offset_w = -1;
+  caffe_rng_uniform(1, 0.0f, static_cast<float>(img_height - crop_size), &crop_offset_h);
+  caffe_rng_uniform(1, 0.0f, static_cast<float>(img_width - crop_size), &crop_offset_w);
+  CHECK_NE(crop_offset_h, -1)
+      << "uniform random sampling inexplicably failed";
+  CHECK_NE(crop_offset_w, -1)
+      << "uniform random sampling inexplicably failed";
+  cv::Rect crop_roi(crop_offset_w, crop_offset_h, crop_size, crop_size);
+  tmp_rand_crop_img_ = img(crop_roi);
+  return tmp_rand_crop_img_;
+}
+
+template<typename Dtype>
+bool DataTransformer<Dtype>::var_sized_image_center_crop_enabled() const {
+  const int crop_size = param_.var_sz_img_center_crop();
+  return crop_size != 0;
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::var_sized_image_center_crop_shape(
+    const vector<int>& prev_shape) const {
+  CHECK(var_sized_image_center_crop_enabled())
+      << "var sized transform must be enabled";
+  const int crop_size = param_.var_sz_img_center_crop();
+  CHECK_EQ(prev_shape.size(), 4)
+      << "input shape should always have 4 axes (NCHW)";
+  vector<int> shape(4);
+  shape[0] = 1;
+  shape[1] = prev_shape[1];
+  shape[2] = crop_size;
+  shape[3] = crop_size;
+  return shape;
+}
+
+template<typename Dtype>
+cv::Mat& DataTransformer<Dtype>::var_sized_image_center_crop(const cv::Mat& img) {
+  const int crop_size = param_.var_sz_img_center_crop();
+  CHECK_GT(crop_size, 0)
+      << "center crop size parameter must be positive";
+  const int img_height = img.rows;
+  const int img_width = img.cols;
+  CHECK_GE(img_height, crop_size)
+      << "crop size parameter must be at least as large as the image height";
+  CHECK_GE(img_width, crop_size)
+      << "crop size parameter must be at least as large as the image width";
+  const int crop_offset_h = (img_height - crop_size) / 2;
+  const int crop_offset_w = (img_width - crop_size) / 2;
+  cv::Rect crop_roi(crop_offset_w, crop_offset_h, crop_size, crop_size);
+  tmp_center_crop_img_ = img(crop_roi);
+  return tmp_center_crop_img_;
+}
+
 #ifndef CPU_ONLY
 
 template<typename Dtype>
@@ -657,7 +901,7 @@ void DataTransformer<Dtype>::TransformPtr(const cv::Mat& cv_img,
 #endif  // USE_OPENCV
 
 template<typename Dtype>
-vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum, bool use_gpu) {
+vector<int> DataTransformer<Dtype>::InferDatumShape(const Datum& datum) {
   if (datum.encoded()) {
 #ifdef USE_OPENCV
     CHECK(!(param_.force_color() && param_.force_gray()))
@@ -669,59 +913,77 @@ vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum, bool use_
     } else {
       cv_img = DecodeDatumToCVMatNative(datum);
     }
-    // InferBlobShape using the cv::image.
-    return InferBlobShape(cv_img, use_gpu);
+    // Infer shape using the cv::image.
+    return InferCVMatShape(cv_img);
 #else
     LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
 #endif  // USE_OPENCV
   }
-  const int crop_size = param_.crop_size();
   const int datum_channels = datum.channels();
   const int datum_height = datum.height();
   const int datum_width = datum.width();
-  // Check dimensions.
-  CHECK_GT(datum_channels, 0);
-  CHECK_GE(datum_height, crop_size);
-  CHECK_GE(datum_width, crop_size);
-  // Build BlobShape.
-  vector<int> shape(4);
-  shape[0] = 1;
-  shape[1] = datum_channels;
-  // if using GPU transform, don't crop
-  if (use_gpu) {
-    shape[2] = datum_height;
-    shape[3] = datum_width;
-  } else {
-    shape[2] = (crop_size) ? crop_size : datum_height;
-    shape[3] = (crop_size) ? crop_size : datum_width;
-  }
-  return shape;
+  vector<int> datum_shape(4);
+  datum_shape[0] = 1;
+  datum_shape[1] = datum_channels;
+  datum_shape[2] = datum_height;
+  datum_shape[3] = datum_width;
+  return datum_shape;
 }
 
 #ifdef USE_OPENCV
 
 template<typename Dtype>
-vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img, bool use_gpu) {
-  const int crop_size = param_.crop_size();
+vector<int> DataTransformer<Dtype>::InferCVMatShape(const cv::Mat& cv_img) {
   const int img_channels = cv_img.channels();
   const int img_height = cv_img.rows;
   const int img_width = cv_img.cols;
-  // Check dimensions.
-  CHECK_GT(img_channels, 0);
-  CHECK_GE(img_height, crop_size);
-  CHECK_GE(img_width, crop_size);
-  // Build BlobShape.
   vector<int> shape(4);
   shape[0] = 1;
   shape[1] = img_channels;
+  shape[2] = img_height;
+  shape[3] = img_width;
+  return shape;
+}
+
+#endif  // USE_OPENCV
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::InferBlobShape(const vector<int>& bottom_shape, bool use_gpu) {
+  const int crop_size = param_.crop_size();
+  CHECK_EQ(bottom_shape.size(), 4);
+  CHECK_EQ(bottom_shape[0], 1);
+  const int bottom_channels = bottom_shape[1];
+  const int bottom_height = bottom_shape[2];
+  const int bottom_width = bottom_shape[3];
+  // Check dimensions.
+  CHECK_GT(bottom_channels, 0);
+  CHECK_GE(bottom_height, crop_size);
+  CHECK_GE(bottom_width, crop_size);
+  // Build BlobShape.
+  vector<int> top_shape(4);
+  top_shape[0] = 1;
+  top_shape[1] = bottom_channels;
+  // if using GPU transform, don't crop
   if (use_gpu) {
-    shape[2] = img_height;
-    shape[3] = img_width;
+    top_shape[2] = bottom_height;
+    top_shape[3] = bottom_width;
   } else {
-    shape[2] = (crop_size) ? crop_size : img_height;
-    shape[3] = (crop_size) ? crop_size : img_width;
+    top_shape[2] = (crop_size) ? crop_size : bottom_height;
+    top_shape[3] = (crop_size) ? crop_size : bottom_width;
   }
-  return shape;
+  return top_shape;
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum, bool use_gpu) {
+  return InferBlobShape(InferDatumShape(datum), use_gpu);
+}
+
+#ifdef USE_OPENCV
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img, bool use_gpu) {
+  return InferBlobShape(InferCVMatShape(cv_img), use_gpu);
 }
 
 #endif  // USE_OPENCV
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 784608aba39..ec9d45477b0 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -1,9 +1,3 @@
-#ifdef USE_OPENCV
-
-#include <opencv2/core/core.hpp>
-
-#endif  // USE_OPENCV
-
 #include "caffe/data_transformer.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/layers/data_layer.hpp"
@@ -177,10 +171,17 @@ DataLayer<Ftype, Btype>::DataLayerSetUp(const vector<Blob*>& bottom, const vecto
   shared_ptr<Datum> sample_datum = sample_only_ ? sample_reader_->sample() : reader_->sample();
   init_offsets();
 
+  // Calculate the variable sized transformed datum shape.
+  vector<int> sample_datum_shape = this->data_transformers_[0]->InferDatumShape(*sample_datum);
+  if (this->data_transformers_[0]->var_sized_transforms_enabled()) {
+    sample_datum_shape =
+        this->data_transformers_[0]->var_sized_transforms_shape(sample_datum_shape);
+  }
+
   // Reshape top[0] and prefetch_data according to the batch_size.
   // Note: all these reshapings here in load_batch are needed only in case of
   // different datum shapes coming from database.
-  vector<int> top_shape = this->data_transformers_[0]->InferBlobShape(*sample_datum);
+  vector<int> top_shape = this->data_transformers_[0]->InferBlobShape(sample_datum_shape);
   top_shape[0] = batch_size;
   top[0]->Reshape(top_shape);
 
@@ -227,14 +228,20 @@ void DataLayer<Ftype, Btype>::load_batch(Batch<Ftype>* batch, int thread_id, siz
   shared_ptr<Datum> datum = reader->full_peek(qid);
   CHECK(datum);
 
+  // Calculate the variable sized transformed datum shape.
+  vector<int> datum_shape = this->data_transformers_[thread_id]->InferDatumShape(*datum);
+  if (this->data_transformers_[thread_id]->var_sized_transforms_enabled()) {
+    datum_shape = this->data_transformers_[thread_id]->var_sized_transforms_shape(datum_shape);
+  }
+
   // Use data_transformer to infer the expected blob shape from datum.
-  vector<int> top_shape = this->data_transformers_[thread_id]->InferBlobShape(*datum,
+  vector<int> top_shape = this->data_transformers_[thread_id]->InferBlobShape(datum_shape,
       use_gpu_transform);
   // Reshape batch according to the batch_size.
   top_shape[0] = batch_size;
   batch->data_.Reshape(top_shape);
   if (use_gpu_transform) {
-    top_shape = this->data_transformers_[thread_id]->InferBlobShape(*datum, false);
+    top_shape = this->data_transformers_[thread_id]->InferBlobShape(datum_shape, false);
     top_shape[0] = batch_size;
     batch->gpu_transformed_data_->Reshape(top_shape);
   }
@@ -262,7 +269,12 @@ void DataLayer<Ftype, Btype>::load_batch(Batch<Ftype>* batch, int thread_id, siz
   size_t current_batch_id = 0UL;
   size_t item_id;
   for (size_t entry = 0; entry < batch_size; ++entry) {
-    datum = reader->full_pop(qid, "Waiting for datum");
+    shared_ptr<Datum> pop_datum = reader->full_pop(qid, "Waiting for datum");
+    datum = pop_datum;
+    // Apply variable-sized transforms.
+    if (this->data_transformers_[thread_id]->var_sized_transforms_enabled()) {
+      datum = this->data_transformers_[thread_id]->VariableSizedTransforms(*datum);
+    }
     item_id = datum->record_id() % batch_size;
     if (datum->channels() > 0) {
       CHECK_EQ(top_shape[1], datum->channels())
@@ -304,7 +316,7 @@ void DataLayer<Ftype, Btype>::load_batch(Batch<Ftype>* batch, int thread_id, siz
       this->data_transformers_[thread_id]->TransformPtrEntry(datum, ptr, rand, this->output_labels_,
           label_ptr);
     }
-    reader->free_push(qid, datum);
+    reader->free_push(qid, pop_datum);
   }
 
   if (use_gpu_transform) {
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index dda9d614183..111fbeddb10 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -480,6 +480,36 @@ message LayerParameter {
 // Message that stores parameters used to apply transformation
 // to the data layer's data
 message TransformationParameter {
+  // When the images in a batch are of different shapes, we need to preprocess
+  // them into the same fixed shape, as downstream operations in caffe require
+  // images within a batch to be of the same shape.
+  //
+  // To transform one image of arbitrary shape into an image of fixed shape,
+  // we allow specifying a sequence of "variable-sized image transforms."
+  // There are three possible transforms, and it is possible for _all of them_
+  // to be enabled at the same time. They are always applied in the same order:
+  // (1) first random resize, (2) then random crop, (3) finally center crop.
+  // The last transform must be either a random crop or a center crop.
+  //
+  // The three supported transforms are as follows:
+  //
+  // 1. Random resize. This takes two parameters, "lower" and "upper," or
+  //    "L" and "U" for short. If the original image has shape (oldW, oldH),
+  //    the shorter side, D = min(oldW, oldH), is calculated. Then a resize
+  //    target size R is chosing uniformly from the interval [L, U], and both
+  //    sides of the original image are resized by a scaling factor R/D to yield
+  //    a new image with shape (R/D * oldW, R/D * oldH).
+  //
+  // 2. Random crop. This takes one crop parameter. A square region is randomly
+  //    chosen from the image for cropping.
+  //
+  // 3. Center crop. This takes one crop parameter. A square region is chosen
+  //    from the center of the image for cropping.
+  //
+  optional uint32 var_sz_img_rand_resize_lower = 10 [default = 0];
+  optional uint32 var_sz_img_rand_resize_upper = 11 [default = 0];
+  optional uint32 var_sz_img_rand_crop = 12 [default = 0];
+  optional uint32 var_sz_img_center_crop = 13 [default = 0];
   // For data pre-processing, we can do simple scaling and subtracting the
   // data mean, if provided. Note that the mean subtraction is always carried
   // out before scaling.
diff --git a/src/caffe/test/test_data_transformer.cpp b/src/caffe/test/test_data_transformer.cpp
index c1791f0aee8..02afaa8b2db 100644
--- a/src/caffe/test/test_data_transformer.cpp
+++ b/src/caffe/test/test_data_transformer.cpp
@@ -341,6 +341,61 @@ TYPED_TEST(DataTransformTest, TestMeanFile) {
   }
 }
 
+template <typename Dtype>
+class VarSzTransformsTest : public ::testing::Test {
+ protected:
+  VarSzTransformsTest()
+    : seed_(1701) {}
+
+  void Run(
+      const TransformationParameter transform_param,
+      const int expected_height, const int expected_width) {
+    const bool unique_pixels = false;  // pixels are equal to label
+    const int label = 42;
+    const int channels = 3;
+    const int height = 4;
+    const int width = 5;
+
+    Datum datum;
+    FillDatum(label, channels, height, width, unique_pixels, &datum);
+    DataTransformer<Dtype> transformer(transform_param, TEST);
+    Caffe::set_random_seed(seed_);
+    transformer.InitRand();
+    shared_ptr<Datum> transformed_datum = transformer.VariableSizedTransforms(datum);
+    EXPECT_EQ(transformed_datum->channels(), 3);
+    EXPECT_EQ(transformed_datum->height(), expected_height);
+    EXPECT_EQ(transformed_datum->width(), expected_width);
+    const int data_count = transformed_datum->data().size();
+    const char* data = &transformed_datum->data().at(0);
+    for (int j = 0; j < data_count; ++j) {
+      EXPECT_EQ(static_cast<int>(data[j]), label);
+    }
+  }
+
+  int seed_;
+};
+
+TYPED_TEST_CASE(VarSzTransformsTest, TestDtypesNoFP16);
+
+TYPED_TEST(VarSzTransformsTest, TestVarSzImgRandomResize) {
+  TransformationParameter transform_param;
+  transform_param.set_var_sz_img_rand_resize_lower(2);
+  transform_param.set_var_sz_img_rand_resize_upper(2);
+  this->Run(transform_param, 2, 3);
+}
+
+TYPED_TEST(VarSzTransformsTest, TestVarSzImgRandomCrop) {
+  TransformationParameter transform_param;
+  transform_param.set_var_sz_img_rand_crop(3);
+  this->Run(transform_param, 3, 3);
+}
+
+TYPED_TEST(VarSzTransformsTest, TestVarSzImgCenterCrop) {
+  TransformationParameter transform_param;
+  transform_param.set_var_sz_img_center_crop(3);
+  this->Run(transform_param, 3, 3);
+}
+
 #ifndef CPU_ONLY
 // GPU-based transform tests
 template <typename Dtype>
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 2fb9ea24a95..212d87a2d05 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -210,30 +210,68 @@ bool DecodeDatum(Datum* datum, bool is_color) {
   }
 }
 
+cv::Mat DatumToCVMat(const Datum& datum) {
+  if (datum.encoded()) {
+    LOG(FATAL) << "Datum encoded";
+  }
+  const int datum_channels = datum.channels();
+  const int datum_height = datum.height();
+  const int datum_width = datum.width();
+  const int datum_size = datum_channels * datum_height * datum_width;
+  CHECK_GT(datum_channels, 0);
+  CHECK_GT(datum_height, 0);
+  CHECK_GT(datum_width, 0);
+  cv::Mat img = cv::Mat::zeros(cv::Size(datum_width, datum_height), CV_8UC(datum_channels));
+  CHECK_EQ(img.channels(), datum_channels);
+  CHECK_EQ(img.rows, datum_height);
+  CHECK_EQ(img.cols, datum_width);
+  const std::string& datum_buf = datum.data();
+  CHECK_EQ(datum_buf.size(), datum_size);
+  const int datum_hw_stride = datum_height * datum_width;
+  for (int h = 0; h < datum_height; ++h) {
+    const int datum_h_offset = h * datum_width;
+    uchar* img_row_ptr = img.ptr<uchar>(h);
+    int img_row_index = 0;
+    for (int w = 0; w < datum_width; ++w) {
+      int datum_index = datum_h_offset + w;
+      for (int c = 0; c < datum_channels; ++c, datum_index += datum_hw_stride) {
+        img_row_ptr[img_row_index++] = datum_buf[datum_index];
+      }
+    }
+  }
+  return img;
+}
+
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
   CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";
-  datum->set_channels(cv_img.channels());
-  datum->set_height(cv_img.rows);
-  datum->set_width(cv_img.cols);
-  datum->clear_data();
+  const int img_channels = cv_img.channels();
+  const int img_height = cv_img.rows;
+  const int img_width = cv_img.cols;
+  const int img_size = img_channels * img_height * img_width;
+  CHECK_GT(img_channels, 0);
+  CHECK_GT(img_height, 0);
+  CHECK_GT(img_width, 0);
+  datum->set_channels(img_channels);
+  datum->set_height(img_height);
+  datum->set_width(img_width);
   datum->clear_float_data();
   datum->set_encoded(false);
-  int datum_channels = datum->channels();
-  int datum_height = datum->height();
-  int datum_width = datum->width();
-  int datum_size = datum_channels * datum_height * datum_width;
-  std::string buffer(datum_size, ' ');
-  for (int h = 0; h < datum_height; ++h) {
-    const uchar* ptr = cv_img.ptr<uchar>(h);
-    int img_index = 0;
-    for (int w = 0; w < datum_width; ++w) {
-      for (int c = 0; c < datum_channels; ++c) {
-        int datum_index = (c * datum_height + h) * datum_width + w;
-        buffer[datum_index] = static_cast<char>(ptr[img_index++]);
+  datum->mutable_data()->reserve(img_size);
+  datum->mutable_data()->resize(img_size);
+  CHECK_EQ(datum->mutable_data()->size(), img_size);
+  char* mut_data = &datum->mutable_data()->at(0);
+  const int datum_hw_stride = img_height * img_width;
+  for (int h = 0; h < img_height; ++h) {
+    const int datum_h_offset = h * img_width;
+    const uchar* row_ptr = cv_img.ptr<uchar>(h);
+    int row_index = 0;
+    for (int w = 0; w < img_width; ++w) {
+      int datum_index = datum_h_offset + w;
+      for (int c = 0; c < img_channels; ++c, datum_index += datum_hw_stride) {
+        mut_data[datum_index] = static_cast<char>(row_ptr[row_index++]);
       }
     }
   }
-  datum->set_data(buffer);
 }
 #endif  // USE_OPENCV
 }  // namespace caffe
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 436f862a556..3c806889b80 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -421,6 +421,22 @@ void caffe_rng_uniform(int n, float a, float b, Dtype* r) {
   }
 }
 
+template <>
+void caffe_rng_uniform<int>(int n, float a, float b, int* r) {
+  CHECK_GE(n, 0);
+  CHECK(r);
+  CHECK_LE(a, b);
+  // NOTE: `boost::uniform_int` uses an inclusive (closed) interval.
+  const int lower = static_cast<int>(a);
+  const int upper = static_cast<int>(b);
+  boost::uniform_int<> incl_range(lower, upper);
+  boost::variate_generator<caffe::rng_t*, boost::uniform_int<> >
+       variate_generator(caffe_rng(), incl_range);
+  for (int i = 0; i < n; ++i) {
+    r[i] = variate_generator();
+  }
+}
+
 template
 void caffe_rng_uniform<float>(int n, float a, float b, float* r);
 

From 491a502d7214e06cc22bd35d7b243d18b9527b79 Mon Sep 17 00:00:00 2001
From: Peter Jin <pjin@nvidia.com>
Date: Thu, 6 Jul 2017 19:02:14 -0700
Subject: [PATCH 2/2] Use versions of `DecodeDatum` which take `cv::Mat` by
 reference.

---
 include/caffe/data_transformer.hpp |  9 +++++----
 include/caffe/util/io.hpp          |  4 +++-
 src/caffe/data_transformer.cpp     | 31 +++++++++++++++---------------
 src/caffe/util/io.cpp              | 18 ++++++++++++-----
 4 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/include/caffe/data_transformer.hpp b/include/caffe/data_transformer.hpp
index db19a79d3c8..75ff509e0d2 100644
--- a/include/caffe/data_transformer.hpp
+++ b/include/caffe/data_transformer.hpp
@@ -238,10 +238,11 @@ class DataTransformer {
 #ifndef CPU_ONLY
   GPUMemory::Workspace mean_values_gpu_;
 #endif
-  shared_ptr<Datum> var_sized_transform_datum_;
-  cv::Mat tmp_rand_resize_img_;
-  cv::Mat tmp_rand_crop_img_;
-  cv::Mat tmp_center_crop_img_;
+  shared_ptr<Datum> varsz_datum_;
+  cv::Mat varsz_orig_img_;
+  cv::Mat varsz_rand_resize_img_;
+  cv::Mat varsz_rand_crop_img_;
+  cv::Mat varsz_center_crop_img_;
 };
 
 }  // namespace caffe
diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp
index 4c9834bcfed..129330ae0e7 100644
--- a/include/caffe/util/io.hpp
+++ b/include/caffe/util/io.hpp
@@ -142,9 +142,11 @@ cv::Mat ReadImageToCVMat(const string& filename,
 cv::Mat ReadImageToCVMat(const string& filename);
 
 cv::Mat DecodeDatumToCVMatNative(const Datum& datum);
+void DecodeDatumToCVMatNative(const Datum& datum, cv::Mat& img);
 cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
+void DecodeDatumToCVMat(const Datum& datum, bool is_color, cv::Mat& img);
 
-cv::Mat DatumToCVMat(const Datum& datum);
+void DatumToCVMat(const Datum& datum, cv::Mat& img);
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
 #endif  // USE_OPENCV
 
diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp
index 6ed75f02fcd..03b4e8c0153 100644
--- a/src/caffe/data_transformer.cpp
+++ b/src/caffe/data_transformer.cpp
@@ -37,7 +37,7 @@ DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param, Ph
       mean_values_.push_back(param_.mean_value(c));
     }
   }
-  var_sized_transform_datum_ = make_shared<Datum>();
+  varsz_datum_ = make_shared<Datum>();
 }
 
 #ifdef USE_OPENCV
@@ -170,20 +170,19 @@ vector<int> DataTransformer<Dtype>::var_sized_transforms_shape(
 
 template<typename Dtype>
 shared_ptr<Datum> DataTransformer<Dtype>::VariableSizedTransforms(const Datum& old_datum) {
-  cv::Mat orig_img;
   if (old_datum.encoded()) {
     CHECK(!(param_.force_color() && param_.force_gray()))
         << "cannot set both force_color and force_gray";
     if (param_.force_color() || param_.force_gray()) {
       // If force_color then decode in color otherwise decode in gray.
-      orig_img = DecodeDatumToCVMat(old_datum, param_.force_color());
+      DecodeDatumToCVMat(old_datum, param_.force_color(), varsz_orig_img_);
     } else {
-      orig_img = DecodeDatumToCVMatNative(old_datum);
+      DecodeDatumToCVMatNative(old_datum, varsz_orig_img_);
     }
   } else {
-    orig_img = DatumToCVMat(old_datum);
+    DatumToCVMat(old_datum, varsz_orig_img_);
   }
-  cv::Mat& img = orig_img;
+  cv::Mat& img = varsz_orig_img_;
   if (var_sized_image_random_resize_enabled()) {
     img = var_sized_image_random_resize(img);
   }
@@ -194,7 +193,7 @@ shared_ptr<Datum> DataTransformer<Dtype>::VariableSizedTransforms(const Datum& o
     img = var_sized_image_center_crop(img);
   }
   {
-    Datum* new_datum = var_sized_transform_datum_.get();
+    Datum* new_datum = varsz_datum_.get();
     CVMatToDatum(img, new_datum);
     if (old_datum.has_label()) {
       new_datum->set_label(old_datum.label());
@@ -203,7 +202,7 @@ shared_ptr<Datum> DataTransformer<Dtype>::VariableSizedTransforms(const Datum& o
     }
     new_datum->set_record_id(old_datum.record_id());
   }
-  return var_sized_transform_datum_;
+  return varsz_datum_;
 }
 
 template<typename Dtype>
@@ -268,11 +267,11 @@ cv::Mat& DataTransformer<Dtype>::var_sized_image_random_resize(cv::Mat& img) {
     CHECK_LE(resize_width, img_width)
         << "cannot downsample height without downsampling width";
     cv::resize(
-        img, tmp_rand_resize_img_,
+        img, varsz_rand_resize_img_,
         cv::Size(resize_width, resize_height),
         0.0, 0.0,
         cv::INTER_AREA);
-    return tmp_rand_resize_img_;
+    return varsz_rand_resize_img_;
   } else if (resize_height > img_height || resize_width > img_width) {
     // Upsample with cubic interpolation.
     CHECK_GE(resize_height, img_height)
@@ -280,11 +279,11 @@ cv::Mat& DataTransformer<Dtype>::var_sized_image_random_resize(cv::Mat& img) {
     CHECK_GE(resize_width, img_width)
         << "cannot upsample height without upsampling width";
     cv::resize(
-        img, tmp_rand_resize_img_,
+        img, varsz_rand_resize_img_,
         cv::Size(resize_width, resize_height),
         0.0, 0.0,
         cv::INTER_CUBIC);
-    return tmp_rand_resize_img_;
+    return varsz_rand_resize_img_;
   } else if (resize_height == img_height && resize_width == img_width) {
     return img;
   }
@@ -336,8 +335,8 @@ cv::Mat& DataTransformer<Dtype>::var_sized_image_random_crop(const cv::Mat& img)
   CHECK_NE(crop_offset_w, -1)
       << "uniform random sampling inexplicably failed";
   cv::Rect crop_roi(crop_offset_w, crop_offset_h, crop_size, crop_size);
-  tmp_rand_crop_img_ = img(crop_roi);
-  return tmp_rand_crop_img_;
+  varsz_rand_crop_img_ = img(crop_roi);
+  return varsz_rand_crop_img_;
 }
 
 template<typename Dtype>
@@ -376,8 +375,8 @@ cv::Mat& DataTransformer<Dtype>::var_sized_image_center_crop(const cv::Mat& img)
   const int crop_offset_h = (img_height - crop_size) / 2;
   const int crop_offset_w = (img_width - crop_size) / 2;
   cv::Rect crop_roi(crop_offset_w, crop_offset_h, crop_size, crop_size);
-  tmp_center_crop_img_ = img(crop_roi);
-  return tmp_center_crop_img_;
+  varsz_center_crop_img_ = img(crop_roi);
+  return varsz_center_crop_img_;
 }
 
 #ifndef CPU_ONLY
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 212d87a2d05..66226b2014b 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -166,6 +166,11 @@ bool ReadFileToDatum(const string& filename, const int label,
 #ifdef USE_OPENCV
 cv::Mat DecodeDatumToCVMatNative(const Datum& datum) {
   cv::Mat cv_img;
+  DecodeDatumToCVMatNative(datum, cv_img);
+  return cv_img;
+}
+
+void DecodeDatumToCVMatNative(const Datum& datum, cv::Mat& cv_img) {
   CHECK(datum.encoded()) << "Datum not encoded";
   const string& data = datum.data();
   std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());
@@ -173,10 +178,15 @@ cv::Mat DecodeDatumToCVMatNative(const Datum& datum) {
   if (!cv_img.data) {
     LOG(ERROR) << "Could not decode datum ";
   }
-  return cv_img;
 }
+
 cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) {
   cv::Mat cv_img;
+  DecodeDatumToCVMat(datum, is_color, cv_img);
+  return cv_img;
+}
+
+void DecodeDatumToCVMat(const Datum& datum, bool is_color, cv::Mat& cv_img) {
   CHECK(datum.encoded()) << "Datum not encoded";
   const string& data = datum.data();
   std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());
@@ -186,7 +196,6 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) {
   if (!cv_img.data) {
     LOG(ERROR) << "Could not decode datum ";
   }
-  return cv_img;
 }
 
 // If Datum is encoded will decoded using DecodeDatumToCVMat and CVMatToDatum
@@ -210,7 +219,7 @@ bool DecodeDatum(Datum* datum, bool is_color) {
   }
 }
 
-cv::Mat DatumToCVMat(const Datum& datum) {
+void DatumToCVMat(const Datum& datum, cv::Mat& img) {
   if (datum.encoded()) {
     LOG(FATAL) << "Datum encoded";
   }
@@ -221,7 +230,7 @@ cv::Mat DatumToCVMat(const Datum& datum) {
   CHECK_GT(datum_channels, 0);
   CHECK_GT(datum_height, 0);
   CHECK_GT(datum_width, 0);
-  cv::Mat img = cv::Mat::zeros(cv::Size(datum_width, datum_height), CV_8UC(datum_channels));
+  img = cv::Mat::zeros(cv::Size(datum_width, datum_height), CV_8UC(datum_channels));
   CHECK_EQ(img.channels(), datum_channels);
   CHECK_EQ(img.rows, datum_height);
   CHECK_EQ(img.cols, datum_width);
@@ -239,7 +248,6 @@ cv::Mat DatumToCVMat(const Datum& datum) {
       }
     }
   }
-  return img;
 }
 
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {