diff --git a/include/caffe/data_transformer.hpp b/include/caffe/data_transformer.hpp index 2608af9fb58..75ff509e0d2 100644 --- a/include/caffe/data_transformer.hpp +++ b/include/caffe/data_transformer.hpp @@ -1,6 +1,12 @@ #ifndef CAFFE_DATA_TRANSFORMER_HPP #define CAFFE_DATA_TRANSFORMER_HPP +#ifdef USE_OPENCV + +#include + +#endif // USE_OPENCV + #include #include @@ -49,6 +55,44 @@ class DataTransformer { void CopyPtrEntry(shared_ptr datum, Dtype* transformed_ptr, size_t& out_sizeof_element, bool output_labels, Dtype* label); + /** + * @brief Whether there are any "variable_sized" transformations defined + * in the data layer's transform_param block. + */ + bool var_sized_transforms_enabled() const; + + /** + * @brief Calculate the final shape from applying the "variable_sized" + * transformations defined in the data layer's transform_param block + * on the provided image, without actually performing any transformations. + * + * @param orig_shape + * The shape of the data to be transformed. + */ + vector var_sized_transforms_shape(const vector& orig_shape) const; + + /** + * @brief Applies "variable_sized" transformations defined in the data layer's + * transform_param block to the data. + * + * @param old_datum + * The source Datum containing data of arbitrary shape. + * @param new_datum + * The destination Datum that will store transformed data of a fixed + * shape. Suitable for other transformations. + */ + shared_ptr VariableSizedTransforms(const Datum& old_datum); + + bool var_sized_image_random_resize_enabled() const; + vector var_sized_image_random_resize_shape(const vector& prev_shape) const; + cv::Mat& var_sized_image_random_resize(cv::Mat& img); + bool var_sized_image_random_crop_enabled() const; + vector var_sized_image_random_crop_shape(const vector& prev_shape) const; + cv::Mat& var_sized_image_random_crop(const cv::Mat& img); + bool var_sized_image_center_crop_enabled() const; + vector var_sized_image_center_crop_shape(const vector& prev_shape) const; + cv::Mat& var_sized_image_center_crop(const cv::Mat& img); + /** * @brief Applies the transformation defined in the data layer's * transform_param block to the data. @@ -137,6 +181,20 @@ class DataTransformer { const std::array& rand); #endif // USE_OPENCV + vector InferDatumShape(const Datum& datum); +#ifdef USE_OPENCV + vector InferCVMatShape(const cv::Mat& img); +#endif // USE_OPENCV + + /** + * @brief Infers the shape of transformed_blob will have when + * the transformation is applied to the data. + * + * @param bottom_shape + * The shape of the data to be transformed. + */ + vector InferBlobShape(const vector& bottom_shape, bool use_gpu = false); + /** * @brief Infers the shape of transformed_blob will have when * the transformation is applied to the data. @@ -180,6 +238,11 @@ class DataTransformer { #ifndef CPU_ONLY GPUMemory::Workspace mean_values_gpu_; #endif + shared_ptr varsz_datum_; + cv::Mat varsz_orig_img_; + cv::Mat varsz_rand_resize_img_; + cv::Mat varsz_rand_crop_img_; + cv::Mat varsz_center_crop_img_; }; } // namespace caffe diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 1a599883ca3..129330ae0e7 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -142,8 +142,11 @@ cv::Mat ReadImageToCVMat(const string& filename, cv::Mat ReadImageToCVMat(const string& filename); cv::Mat DecodeDatumToCVMatNative(const Datum& datum); +void DecodeDatumToCVMatNative(const Datum& datum, cv::Mat& img); cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color); +void DecodeDatumToCVMat(const Datum& datum, bool is_color, cv::Mat& img); +void DatumToCVMat(const Datum& datum, cv::Mat& img); void CVMatToDatum(const cv::Mat& cv_img, Datum* datum); #endif // USE_OPENCV diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index 24235ee9f7b..d77b664e2c6 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -102,6 +102,9 @@ Dtype caffe_nextafter(const Dtype b); template void caffe_rng_uniform(int n, float a, float b, Dtype* r); +template <> +void caffe_rng_uniform(int n, float a, float b, int* r); + template void caffe_rng_gaussian(int n, float mu, float sigma, Dtype* r); diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index 4ff07b3133b..03b4e8c0153 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -1,6 +1,7 @@ #ifdef USE_OPENCV #include +#include #endif // USE_OPENCV @@ -36,6 +37,7 @@ DataTransformer::DataTransformer(const TransformationParameter& param, Ph mean_values_.push_back(param_.mean_value(c)); } } + varsz_datum_ = make_shared(); } #ifdef USE_OPENCV @@ -136,6 +138,247 @@ void DataTransformer::Fill3Randoms(unsigned int *rand) const { } } +template +bool DataTransformer::var_sized_transforms_enabled() const { + return var_sized_image_random_resize_enabled() || + var_sized_image_random_crop_enabled() || + var_sized_image_center_crop_enabled(); +} + +template +vector DataTransformer::var_sized_transforms_shape( + const vector& orig_shape) const { + CHECK_EQ(orig_shape.size(), 4); + // All of the transforms (random resize, random crop, center crop) + // can be enabled, and they operate sequentially, one after the other. + vector shape(orig_shape); + if (var_sized_image_random_resize_enabled()) { + shape = var_sized_image_random_resize_shape(shape); + } + if (var_sized_image_random_crop_enabled()) { + shape = var_sized_image_random_crop_shape(shape); + } + if (var_sized_image_center_crop_enabled()) { + shape = var_sized_image_center_crop_shape(shape); + } + CHECK_NE(shape[2], 0) + << "variable sized transform has invalid output height; did you forget to crop?"; + CHECK_NE(shape[3], 0) + << "variable sized transform has invalid output width; did you forget to crop?"; + return shape; +} + +template +shared_ptr DataTransformer::VariableSizedTransforms(const Datum& old_datum) { + if (old_datum.encoded()) { + CHECK(!(param_.force_color() && param_.force_gray())) + << "cannot set both force_color and force_gray"; + if (param_.force_color() || param_.force_gray()) { + // If force_color then decode in color otherwise decode in gray. + DecodeDatumToCVMat(old_datum, param_.force_color(), varsz_orig_img_); + } else { + DecodeDatumToCVMatNative(old_datum, varsz_orig_img_); + } + } else { + DatumToCVMat(old_datum, varsz_orig_img_); + } + cv::Mat& img = varsz_orig_img_; + if (var_sized_image_random_resize_enabled()) { + img = var_sized_image_random_resize(img); + } + if (var_sized_image_random_crop_enabled()) { + img = var_sized_image_random_crop(img); + } + if (var_sized_image_center_crop_enabled()) { + img = var_sized_image_center_crop(img); + } + { + Datum* new_datum = varsz_datum_.get(); + CVMatToDatum(img, new_datum); + if (old_datum.has_label()) { + new_datum->set_label(old_datum.label()); + } else { + new_datum->clear_label(); + } + new_datum->set_record_id(old_datum.record_id()); + } + return varsz_datum_; +} + +template +bool DataTransformer::var_sized_image_random_resize_enabled() const { + const int resize_lower = param_.var_sz_img_rand_resize_lower(); + const int resize_upper = param_.var_sz_img_rand_resize_upper(); + if (resize_lower == 0 && resize_upper == 0) { + return false; + } else if (resize_lower != 0 && resize_upper != 0) { + return true; + } + LOG(FATAL) + << "random resize 'lower' and 'upper' parameters must either " + "both be zero or both be nonzero"; +} + +template +vector DataTransformer::var_sized_image_random_resize_shape( + const vector& prev_shape) const { + CHECK(var_sized_image_random_resize_enabled()) + << "var sized transform must be enabled"; + CHECK_EQ(prev_shape.size(), 4) + << "input shape should always have 4 axes (NCHW)"; + vector shape(4); + shape[0] = 1; + shape[1] = prev_shape[1]; + // The output of a random resize is itself a variable sized image. + // By itself a random resize cannot produce an image that is valid input for + // downstream transformations, and must instead be terminated by a + // variable-sized crop (either random or center). + shape[2] = 0; + shape[3] = 0; + return shape; +} + +template +cv::Mat& DataTransformer::var_sized_image_random_resize(cv::Mat& img) { + const int resize_lower = param_.var_sz_img_rand_resize_lower(); + const int resize_upper = param_.var_sz_img_rand_resize_upper(); + CHECK_GT(resize_lower, 0) + << "random resize lower bound parameter must be positive"; + CHECK_GT(resize_upper, 0) + << "random resize lower bound parameter must be positive"; + int resize_size = -1; + caffe_rng_uniform( + 1, + static_cast(resize_lower), static_cast(resize_upper), + &resize_size); + CHECK_NE(resize_size, -1) + << "uniform random sampling inexplicably failed"; + const int img_height = img.rows; + const int img_width = img.cols; + const double scale = (img_width >= img_height) ? + ((static_cast(resize_size)) / (static_cast(img_height))) : + ((static_cast(resize_size)) / (static_cast(img_width))); + const int resize_height = static_cast(std::round(scale * static_cast(img_height))); + const int resize_width = static_cast(std::round(scale * static_cast(img_width))); + if (resize_height < img_height || resize_width < img_width) { + // Downsample with pixel area relation interpolation. + CHECK_LE(resize_height, img_height) + << "cannot downsample width without downsampling height"; + CHECK_LE(resize_width, img_width) + << "cannot downsample height without downsampling width"; + cv::resize( + img, varsz_rand_resize_img_, + cv::Size(resize_width, resize_height), + 0.0, 0.0, + cv::INTER_AREA); + return varsz_rand_resize_img_; + } else if (resize_height > img_height || resize_width > img_width) { + // Upsample with cubic interpolation. + CHECK_GE(resize_height, img_height) + << "cannot upsample width without upsampling height"; + CHECK_GE(resize_width, img_width) + << "cannot upsample height without upsampling width"; + cv::resize( + img, varsz_rand_resize_img_, + cv::Size(resize_width, resize_height), + 0.0, 0.0, + cv::INTER_CUBIC); + return varsz_rand_resize_img_; + } else if (resize_height == img_height && resize_width == img_width) { + return img; + } + LOG(FATAL) + << "unreachable random resize shape: (" + << img_width << ", " << img_height << ") => (" + << resize_width << ", " << resize_height << ")"; +} + +template +bool DataTransformer::var_sized_image_random_crop_enabled() const { + const int crop_size = param_.var_sz_img_rand_crop(); + return crop_size != 0; +} + +template +vector DataTransformer::var_sized_image_random_crop_shape( + const vector& prev_shape) const { + CHECK(var_sized_image_random_crop_enabled()) + << "var sized transform must be enabled"; + const int crop_size = param_.var_sz_img_rand_crop(); + CHECK_EQ(prev_shape.size(), 4) + << "input shape should always have 4 axes (NCHW)"; + vector shape(4); + shape[0] = 1; + shape[1] = prev_shape[1]; + shape[2] = crop_size; + shape[3] = crop_size; + return shape; +} + +template +cv::Mat& DataTransformer::var_sized_image_random_crop(const cv::Mat& img) { + const int crop_size = param_.var_sz_img_rand_crop(); + CHECK_GT(crop_size, 0) + << "random crop size parameter must be positive"; + const int img_height = img.rows; + const int img_width = img.cols; + CHECK_GE(img_height, crop_size) + << "crop size parameter must be at least as large as the image height"; + CHECK_GE(img_width, crop_size) + << "crop size parameter must be at least as large as the image width"; + int crop_offset_h = -1; + int crop_offset_w = -1; + caffe_rng_uniform(1, 0.0f, static_cast(img_height - crop_size), &crop_offset_h); + caffe_rng_uniform(1, 0.0f, static_cast(img_width - crop_size), &crop_offset_w); + CHECK_NE(crop_offset_h, -1) + << "uniform random sampling inexplicably failed"; + CHECK_NE(crop_offset_w, -1) + << "uniform random sampling inexplicably failed"; + cv::Rect crop_roi(crop_offset_w, crop_offset_h, crop_size, crop_size); + varsz_rand_crop_img_ = img(crop_roi); + return varsz_rand_crop_img_; +} + +template +bool DataTransformer::var_sized_image_center_crop_enabled() const { + const int crop_size = param_.var_sz_img_center_crop(); + return crop_size != 0; +} + +template +vector DataTransformer::var_sized_image_center_crop_shape( + const vector& prev_shape) const { + CHECK(var_sized_image_center_crop_enabled()) + << "var sized transform must be enabled"; + const int crop_size = param_.var_sz_img_center_crop(); + CHECK_EQ(prev_shape.size(), 4) + << "input shape should always have 4 axes (NCHW)"; + vector shape(4); + shape[0] = 1; + shape[1] = prev_shape[1]; + shape[2] = crop_size; + shape[3] = crop_size; + return shape; +} + +template +cv::Mat& DataTransformer::var_sized_image_center_crop(const cv::Mat& img) { + const int crop_size = param_.var_sz_img_center_crop(); + CHECK_GT(crop_size, 0) + << "center crop size parameter must be positive"; + const int img_height = img.rows; + const int img_width = img.cols; + CHECK_GE(img_height, crop_size) + << "crop size parameter must be at least as large as the image height"; + CHECK_GE(img_width, crop_size) + << "crop size parameter must be at least as large as the image width"; + const int crop_offset_h = (img_height - crop_size) / 2; + const int crop_offset_w = (img_width - crop_size) / 2; + cv::Rect crop_roi(crop_offset_w, crop_offset_h, crop_size, crop_size); + varsz_center_crop_img_ = img(crop_roi); + return varsz_center_crop_img_; +} + #ifndef CPU_ONLY template @@ -657,7 +900,7 @@ void DataTransformer::TransformPtr(const cv::Mat& cv_img, #endif // USE_OPENCV template -vector DataTransformer::InferBlobShape(const Datum& datum, bool use_gpu) { +vector DataTransformer::InferDatumShape(const Datum& datum) { if (datum.encoded()) { #ifdef USE_OPENCV CHECK(!(param_.force_color() && param_.force_gray())) @@ -669,59 +912,77 @@ vector DataTransformer::InferBlobShape(const Datum& datum, bool use_ } else { cv_img = DecodeDatumToCVMatNative(datum); } - // InferBlobShape using the cv::image. - return InferBlobShape(cv_img, use_gpu); + // Infer shape using the cv::image. + return InferCVMatShape(cv_img); #else LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV."; #endif // USE_OPENCV } - const int crop_size = param_.crop_size(); const int datum_channels = datum.channels(); const int datum_height = datum.height(); const int datum_width = datum.width(); - // Check dimensions. - CHECK_GT(datum_channels, 0); - CHECK_GE(datum_height, crop_size); - CHECK_GE(datum_width, crop_size); - // Build BlobShape. - vector shape(4); - shape[0] = 1; - shape[1] = datum_channels; - // if using GPU transform, don't crop - if (use_gpu) { - shape[2] = datum_height; - shape[3] = datum_width; - } else { - shape[2] = (crop_size) ? crop_size : datum_height; - shape[3] = (crop_size) ? crop_size : datum_width; - } - return shape; + vector datum_shape(4); + datum_shape[0] = 1; + datum_shape[1] = datum_channels; + datum_shape[2] = datum_height; + datum_shape[3] = datum_width; + return datum_shape; } #ifdef USE_OPENCV template -vector DataTransformer::InferBlobShape(const cv::Mat& cv_img, bool use_gpu) { - const int crop_size = param_.crop_size(); +vector DataTransformer::InferCVMatShape(const cv::Mat& cv_img) { const int img_channels = cv_img.channels(); const int img_height = cv_img.rows; const int img_width = cv_img.cols; - // Check dimensions. - CHECK_GT(img_channels, 0); - CHECK_GE(img_height, crop_size); - CHECK_GE(img_width, crop_size); - // Build BlobShape. vector shape(4); shape[0] = 1; shape[1] = img_channels; + shape[2] = img_height; + shape[3] = img_width; + return shape; +} + +#endif // USE_OPENCV + +template +vector DataTransformer::InferBlobShape(const vector& bottom_shape, bool use_gpu) { + const int crop_size = param_.crop_size(); + CHECK_EQ(bottom_shape.size(), 4); + CHECK_EQ(bottom_shape[0], 1); + const int bottom_channels = bottom_shape[1]; + const int bottom_height = bottom_shape[2]; + const int bottom_width = bottom_shape[3]; + // Check dimensions. + CHECK_GT(bottom_channels, 0); + CHECK_GE(bottom_height, crop_size); + CHECK_GE(bottom_width, crop_size); + // Build BlobShape. + vector top_shape(4); + top_shape[0] = 1; + top_shape[1] = bottom_channels; + // if using GPU transform, don't crop if (use_gpu) { - shape[2] = img_height; - shape[3] = img_width; + top_shape[2] = bottom_height; + top_shape[3] = bottom_width; } else { - shape[2] = (crop_size) ? crop_size : img_height; - shape[3] = (crop_size) ? crop_size : img_width; + top_shape[2] = (crop_size) ? crop_size : bottom_height; + top_shape[3] = (crop_size) ? crop_size : bottom_width; } - return shape; + return top_shape; +} + +template +vector DataTransformer::InferBlobShape(const Datum& datum, bool use_gpu) { + return InferBlobShape(InferDatumShape(datum), use_gpu); +} + +#ifdef USE_OPENCV + +template +vector DataTransformer::InferBlobShape(const cv::Mat& cv_img, bool use_gpu) { + return InferBlobShape(InferCVMatShape(cv_img), use_gpu); } #endif // USE_OPENCV diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 7791502a8a3..cd7907da68d 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -1,9 +1,3 @@ -#ifdef USE_OPENCV - -#include - -#endif // USE_OPENCV - #include "caffe/data_transformer.hpp" #include "caffe/layer.hpp" #include "caffe/layers/data_layer.hpp" @@ -171,10 +165,17 @@ DataLayer::DataLayerSetUp(const vector& bottom, const vecto shared_ptr sample_datum = sample_only_ ? sample_reader_->sample() : reader_->sample(); init_offsets(); + // Calculate the variable sized transformed datum shape. + vector sample_datum_shape = this->data_transformers_[0]->InferDatumShape(*sample_datum); + if (this->data_transformers_[0]->var_sized_transforms_enabled()) { + sample_datum_shape = + this->data_transformers_[0]->var_sized_transforms_shape(sample_datum_shape); + } + // Reshape top[0] and prefetch_data according to the batch_size. // Note: all these reshapings here in load_batch are needed only in case of // different datum shapes coming from database. - vector top_shape = this->data_transformers_[0]->InferBlobShape(*sample_datum); + vector top_shape = this->data_transformers_[0]->InferBlobShape(sample_datum_shape); top_shape[0] = batch_size; top[0]->Reshape(top_shape); @@ -226,14 +227,20 @@ void DataLayer::load_batch(Batch* batch, int thread_id, siz shared_ptr datum = reader->full_peek(qid); CHECK(datum); + // Calculate the variable sized transformed datum shape. + vector datum_shape = this->data_transformers_[thread_id]->InferDatumShape(*datum); + if (this->data_transformers_[thread_id]->var_sized_transforms_enabled()) { + datum_shape = this->data_transformers_[thread_id]->var_sized_transforms_shape(datum_shape); + } + // Use data_transformer to infer the expected blob shape from datum. - vector top_shape = this->data_transformers_[thread_id]->InferBlobShape(*datum, + vector top_shape = this->data_transformers_[thread_id]->InferBlobShape(datum_shape, use_gpu_transform); // Reshape batch according to the batch_size. top_shape[0] = batch_size; batch->data_.Reshape(top_shape); if (use_gpu_transform) { - top_shape = this->data_transformers_[thread_id]->InferBlobShape(*datum, false); + top_shape = this->data_transformers_[thread_id]->InferBlobShape(datum_shape, false); top_shape[0] = batch_size; batch->gpu_transformed_data_->Reshape(top_shape); } @@ -261,7 +268,12 @@ void DataLayer::load_batch(Batch* batch, int thread_id, siz size_t current_batch_id = 0UL; size_t item_id; for (size_t entry = 0; entry < batch_size; ++entry) { - datum = reader->full_pop(qid, "Waiting for datum"); + shared_ptr pop_datum = reader->full_pop(qid, "Waiting for datum"); + datum = pop_datum; + // Apply variable-sized transforms. + if (this->data_transformers_[thread_id]->var_sized_transforms_enabled()) { + datum = this->data_transformers_[thread_id]->VariableSizedTransforms(*datum); + } item_id = datum->record_id() % batch_size; if (datum->channels() > 0) { CHECK_EQ(top_shape[1], datum->channels()) @@ -303,7 +315,7 @@ void DataLayer::load_batch(Batch* batch, int thread_id, siz this->data_transformers_[thread_id]->TransformPtrEntry(datum, ptr, rand, this->output_labels_, label_ptr); } - reader->free_push(qid, datum); + reader->free_push(qid, pop_datum); } if (use_gpu_transform) { diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 278ac846cb5..68f9ec9140c 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -490,6 +490,36 @@ message LayerParameter { // Message that stores parameters used to apply transformation // to the data layer's data message TransformationParameter { + // When the images in a batch are of different shapes, we need to preprocess + // them into the same fixed shape, as downstream operations in caffe require + // images within a batch to be of the same shape. + // + // To transform one image of arbitrary shape into an image of fixed shape, + // we allow specifying a sequence of "variable-sized image transforms." + // There are three possible transforms, and it is possible for _all of them_ + // to be enabled at the same time. They are always applied in the same order: + // (1) first random resize, (2) then random crop, (3) finally center crop. + // The last transform must be either a random crop or a center crop. + // + // The three supported transforms are as follows: + // + // 1. Random resize. This takes two parameters, "lower" and "upper," or + // "L" and "U" for short. If the original image has shape (oldW, oldH), + // the shorter side, D = min(oldW, oldH), is calculated. Then a resize + // target size R is chosing uniformly from the interval [L, U], and both + // sides of the original image are resized by a scaling factor R/D to yield + // a new image with shape (R/D * oldW, R/D * oldH). + // + // 2. Random crop. This takes one crop parameter. A square region is randomly + // chosen from the image for cropping. + // + // 3. Center crop. This takes one crop parameter. A square region is chosen + // from the center of the image for cropping. + // + optional uint32 var_sz_img_rand_resize_lower = 10 [default = 0]; + optional uint32 var_sz_img_rand_resize_upper = 11 [default = 0]; + optional uint32 var_sz_img_rand_crop = 12 [default = 0]; + optional uint32 var_sz_img_center_crop = 13 [default = 0]; // For data pre-processing, we can do simple scaling and subtracting the // data mean, if provided. Note that the mean subtraction is always carried // out before scaling. diff --git a/src/caffe/test/test_data_transformer.cpp b/src/caffe/test/test_data_transformer.cpp index c1791f0aee8..02afaa8b2db 100644 --- a/src/caffe/test/test_data_transformer.cpp +++ b/src/caffe/test/test_data_transformer.cpp @@ -341,6 +341,61 @@ TYPED_TEST(DataTransformTest, TestMeanFile) { } } +template +class VarSzTransformsTest : public ::testing::Test { + protected: + VarSzTransformsTest() + : seed_(1701) {} + + void Run( + const TransformationParameter transform_param, + const int expected_height, const int expected_width) { + const bool unique_pixels = false; // pixels are equal to label + const int label = 42; + const int channels = 3; + const int height = 4; + const int width = 5; + + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + DataTransformer transformer(transform_param, TEST); + Caffe::set_random_seed(seed_); + transformer.InitRand(); + shared_ptr transformed_datum = transformer.VariableSizedTransforms(datum); + EXPECT_EQ(transformed_datum->channels(), 3); + EXPECT_EQ(transformed_datum->height(), expected_height); + EXPECT_EQ(transformed_datum->width(), expected_width); + const int data_count = transformed_datum->data().size(); + const char* data = &transformed_datum->data().at(0); + for (int j = 0; j < data_count; ++j) { + EXPECT_EQ(static_cast(data[j]), label); + } + } + + int seed_; +}; + +TYPED_TEST_CASE(VarSzTransformsTest, TestDtypesNoFP16); + +TYPED_TEST(VarSzTransformsTest, TestVarSzImgRandomResize) { + TransformationParameter transform_param; + transform_param.set_var_sz_img_rand_resize_lower(2); + transform_param.set_var_sz_img_rand_resize_upper(2); + this->Run(transform_param, 2, 3); +} + +TYPED_TEST(VarSzTransformsTest, TestVarSzImgRandomCrop) { + TransformationParameter transform_param; + transform_param.set_var_sz_img_rand_crop(3); + this->Run(transform_param, 3, 3); +} + +TYPED_TEST(VarSzTransformsTest, TestVarSzImgCenterCrop) { + TransformationParameter transform_param; + transform_param.set_var_sz_img_center_crop(3); + this->Run(transform_param, 3, 3); +} + #ifndef CPU_ONLY // GPU-based transform tests template diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index 2fb9ea24a95..66226b2014b 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -166,6 +166,11 @@ bool ReadFileToDatum(const string& filename, const int label, #ifdef USE_OPENCV cv::Mat DecodeDatumToCVMatNative(const Datum& datum) { cv::Mat cv_img; + DecodeDatumToCVMatNative(datum, cv_img); + return cv_img; +} + +void DecodeDatumToCVMatNative(const Datum& datum, cv::Mat& cv_img) { CHECK(datum.encoded()) << "Datum not encoded"; const string& data = datum.data(); std::vector vec_data(data.c_str(), data.c_str() + data.size()); @@ -173,10 +178,15 @@ cv::Mat DecodeDatumToCVMatNative(const Datum& datum) { if (!cv_img.data) { LOG(ERROR) << "Could not decode datum "; } - return cv_img; } + cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) { cv::Mat cv_img; + DecodeDatumToCVMat(datum, is_color, cv_img); + return cv_img; +} + +void DecodeDatumToCVMat(const Datum& datum, bool is_color, cv::Mat& cv_img) { CHECK(datum.encoded()) << "Datum not encoded"; const string& data = datum.data(); std::vector vec_data(data.c_str(), data.c_str() + data.size()); @@ -186,7 +196,6 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) { if (!cv_img.data) { LOG(ERROR) << "Could not decode datum "; } - return cv_img; } // If Datum is encoded will decoded using DecodeDatumToCVMat and CVMatToDatum @@ -210,30 +219,67 @@ bool DecodeDatum(Datum* datum, bool is_color) { } } +void DatumToCVMat(const Datum& datum, cv::Mat& img) { + if (datum.encoded()) { + LOG(FATAL) << "Datum encoded"; + } + const int datum_channels = datum.channels(); + const int datum_height = datum.height(); + const int datum_width = datum.width(); + const int datum_size = datum_channels * datum_height * datum_width; + CHECK_GT(datum_channels, 0); + CHECK_GT(datum_height, 0); + CHECK_GT(datum_width, 0); + img = cv::Mat::zeros(cv::Size(datum_width, datum_height), CV_8UC(datum_channels)); + CHECK_EQ(img.channels(), datum_channels); + CHECK_EQ(img.rows, datum_height); + CHECK_EQ(img.cols, datum_width); + const std::string& datum_buf = datum.data(); + CHECK_EQ(datum_buf.size(), datum_size); + const int datum_hw_stride = datum_height * datum_width; + for (int h = 0; h < datum_height; ++h) { + const int datum_h_offset = h * datum_width; + uchar* img_row_ptr = img.ptr(h); + int img_row_index = 0; + for (int w = 0; w < datum_width; ++w) { + int datum_index = datum_h_offset + w; + for (int c = 0; c < datum_channels; ++c, datum_index += datum_hw_stride) { + img_row_ptr[img_row_index++] = datum_buf[datum_index]; + } + } + } +} + void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) { CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte"; - datum->set_channels(cv_img.channels()); - datum->set_height(cv_img.rows); - datum->set_width(cv_img.cols); - datum->clear_data(); + const int img_channels = cv_img.channels(); + const int img_height = cv_img.rows; + const int img_width = cv_img.cols; + const int img_size = img_channels * img_height * img_width; + CHECK_GT(img_channels, 0); + CHECK_GT(img_height, 0); + CHECK_GT(img_width, 0); + datum->set_channels(img_channels); + datum->set_height(img_height); + datum->set_width(img_width); datum->clear_float_data(); datum->set_encoded(false); - int datum_channels = datum->channels(); - int datum_height = datum->height(); - int datum_width = datum->width(); - int datum_size = datum_channels * datum_height * datum_width; - std::string buffer(datum_size, ' '); - for (int h = 0; h < datum_height; ++h) { - const uchar* ptr = cv_img.ptr(h); - int img_index = 0; - for (int w = 0; w < datum_width; ++w) { - for (int c = 0; c < datum_channels; ++c) { - int datum_index = (c * datum_height + h) * datum_width + w; - buffer[datum_index] = static_cast(ptr[img_index++]); + datum->mutable_data()->reserve(img_size); + datum->mutable_data()->resize(img_size); + CHECK_EQ(datum->mutable_data()->size(), img_size); + char* mut_data = &datum->mutable_data()->at(0); + const int datum_hw_stride = img_height * img_width; + for (int h = 0; h < img_height; ++h) { + const int datum_h_offset = h * img_width; + const uchar* row_ptr = cv_img.ptr(h); + int row_index = 0; + for (int w = 0; w < img_width; ++w) { + int datum_index = datum_h_offset + w; + for (int c = 0; c < img_channels; ++c, datum_index += datum_hw_stride) { + mut_data[datum_index] = static_cast(row_ptr[row_index++]); } } } - datum->set_data(buffer); } #endif // USE_OPENCV } // namespace caffe diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 436f862a556..3c806889b80 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -421,6 +421,22 @@ void caffe_rng_uniform(int n, float a, float b, Dtype* r) { } } +template <> +void caffe_rng_uniform(int n, float a, float b, int* r) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_LE(a, b); + // NOTE: `boost::uniform_int` uses an inclusive (closed) interval. + const int lower = static_cast(a); + const int upper = static_cast(b); + boost::uniform_int<> incl_range(lower, upper); + boost::variate_generator > + variate_generator(caffe_rng(), incl_range); + for (int i = 0; i < n; ++i) { + r[i] = variate_generator(); + } +} + template void caffe_rng_uniform(int n, float a, float b, float* r);