diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index 30dbfd53758..2f822de8aec 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -291,6 +291,8 @@ class Layer { param_propagate_down_[param_id] = value; } + inline Phase phase() { return phase_; } + protected: /** The protobuf that stores the layer parameters */ diff --git a/include/caffe/layers/dropout_layer.hpp b/include/caffe/layers/dropout_layer.hpp index e83143bc3cc..fc19a392101 100644 --- a/include/caffe/layers/dropout_layer.hpp +++ b/include/caffe/layers/dropout_layer.hpp @@ -73,6 +73,7 @@ class DropoutLayer : public NeuronLayer { /// the scale for undropped inputs at train time @f$ 1 / (1 - p) @f$ Dtype scale_; unsigned int uint_thres_; + bool scale_train_; }; } // namespace caffe diff --git a/include/caffe/layers/roi_pooling_layer.hpp b/include/caffe/layers/roi_pooling_layer.hpp new file mode 100644 index 00000000000..a082721ce4c --- /dev/null +++ b/include/caffe/layers/roi_pooling_layer.hpp @@ -0,0 +1,84 @@ +#ifndef CAFFE_ROI_POOLING_LAYER_HPP_ +#define CAFFE_ROI_POOLING_LAYER_HPP_ + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Perform max pooling on regions of interest specified by input, takes + * as input N feature maps and a list of R regions of interest. + * + * ROIPoolingLayer takes 2 inputs and produces 1 output. bottom[0] is + * [N x C x H x W] feature maps on which pooling is performed. bottom[1] is + * [R x 5] containing a list R ROI tuples with batch index and coordinates of + * regions of interest. Each row in bottom[1] is a ROI tuple in format + * [batch_index x1 y1 x2 y2], where batch_index corresponds to the index of + * instance in the first input and x1 y1 x2 y2 are 0-indexed coordinates + * of ROI rectangle (including its boundaries). + * + * For each of the R ROIs, max-pooling is performed over pooled_h x pooled_w + * output bins (specified in roi_pooling_param). The pooling bin sizes are + * adaptively set such that they tile ROI rectangle in the indexed feature + * map. The pooling region of vertical bin ph in [0, pooled_h) is computed as + * + * start_ph (included) = y1 + floor(ph * (y2 - y1 + 1) / pooled_h) + * end_ph (excluded) = y1 + ceil((ph + 1) * (y2 - y1 + 1) / pooled_h) + * + * and similar horizontal bins. + * + * @param param provides ROIPoolingParameter roi_pooling_param, + * with ROIPoolingLayer options: + * - pooled_h. The pooled output height. + * - pooled_w. The pooled output width + * - spatial_scale. Multiplicative spatial scale factor to translate ROI + * coordinates from their input scale to the scale used when pooling. 
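+ *
+ * As an illustrative example (numbers chosen here, not taken from the code):
+ * with y1 = 0, y2 = 4 (ROI height 5) and pooled_h = 2, the formulas above
+ * give vertical bins [0, 3) and [2, 5). Because of the floor/ceil rounding,
+ * adjacent bins may overlap by one row or column whenever the ROI size is
+ * not an exact multiple of the pooled size.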
+ * + * Fast R-CNN + * Written by Ross Girshick + */ + +template +class ROIPoolingLayer : public Layer { + public: + explicit ROIPoolingLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "ROIPooling"; } + + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int MaxBottomBlobs() const { return 2; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline int MaxTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + Dtype spatial_scale_; + Blob max_idx_; +}; + +} // namespace caffe + +#endif // CAFFE_ROI_POOLING_LAYER_HPP_ diff --git a/include/caffe/layers/smooth_l1_loss_layer.hpp b/include/caffe/layers/smooth_l1_loss_layer.hpp new file mode 100644 index 00000000000..44082558077 --- /dev/null +++ b/include/caffe/layers/smooth_l1_loss_layer.hpp @@ -0,0 +1,65 @@ +#ifndef CAFFE_SMOOTH_L1_LOSS_LAYER_HPP_ +#define CAFFE_SMOOTH_L1_LOSS_LAYER_HPP_ + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +#include "caffe/layers/loss_layer.hpp" + +namespace caffe { + +/** + * @brief SmoothL1LossLayer + * + * Fast R-CNN + * Written by Ross Girshick + */ +template +class SmoothL1LossLayer : public LossLayer { + public: + explicit SmoothL1LossLayer(const LayerParameter& param) + : LossLayer(param), diff_() {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "SmoothL1Loss"; } + + virtual inline int ExactNumBottomBlobs() const { return -1; } + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int MaxBottomBlobs() const { return 4; } + + /** + * Unlike most loss layers, in the SmoothL1LossLayer we can backpropagate + * to both inputs -- override to return true and always allow force_backward. 
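+ *
+ * Note that the loss depends on the two inputs only through their
+ * (optionally weighted) difference, so the gradients with respect to
+ * bottom[0] and bottom[1] are equal up to sign; see the sign handling in
+ * Backward_gpu.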
+ */ + virtual inline bool AllowForceBackward(const int bottom_index) const { + return true; + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob diff_; + Blob errors_; + Blob ones_; + bool has_weights_; + Dtype sigma2_; +}; + +} // namespace caffe + +#endif // CAFFE_SMOOTH_L1_LOSS_LAYER_HPP_ diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 72659a4f44e..18cab43b3d8 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -399,7 +399,6 @@ BOOST_PYTHON_MODULE(_caffe) { bp::def("solver_rank", &Caffe::solver_rank); bp::def("set_solver_rank", &Caffe::set_solver_rank); bp::def("set_multiprocess", &Caffe::set_multiprocess); - bp::def("layer_type_list", &LayerRegistry::LayerTypeList); bp::class_, shared_ptr >, boost::noncopyable >("Net", diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp index 533ab26c04d..cedd4a14f4c 100644 --- a/src/caffe/layers/dropout_layer.cpp +++ b/src/caffe/layers/dropout_layer.cpp @@ -16,6 +16,7 @@ void DropoutLayer::LayerSetUp(const vector*>& bottom, DCHECK(threshold_ < 1.); scale_ = 1. / (1. - threshold_); uint_thres_ = static_cast(UINT_MAX * threshold_); + scale_train_ = this->layer_param_.dropout_param().scale_train(); } template @@ -37,11 +38,20 @@ void DropoutLayer::Forward_cpu(const vector*>& bottom, if (this->phase_ == TRAIN) { // Create random numbers caffe_rng_bernoulli(count, 1. - threshold_, mask); - for (int i = 0; i < count; ++i) { - top_data[i] = bottom_data[i] * mask[i] * scale_; + if (scale_train_) { + for (int i = 0; i < count; ++i) { + top_data[i] = bottom_data[i] * mask[i] * scale_; + } + } else { + for (int i = 0; i < count; ++i) { + top_data[i] = bottom_data[i] * mask[i]; + } } } else { caffe_copy(bottom[0]->count(), bottom_data, top_data); + if (!scale_train_) { + caffe_scal(count, 1. / scale_, top_data); + } } } @@ -55,11 +65,20 @@ void DropoutLayer::Backward_cpu(const vector*>& top, if (this->phase_ == TRAIN) { const unsigned int* mask = rand_vec_.cpu_data(); const int count = bottom[0]->count(); - for (int i = 0; i < count; ++i) { - bottom_diff[i] = top_diff[i] * mask[i] * scale_; + if (scale_train_) { + for (int i = 0; i < count; ++i) { + bottom_diff[i] = top_diff[i] * mask[i] * scale_; + } + } else { + for (int i = 0; i < count; ++i) { + bottom_diff[i] = top_diff[i] * mask[i]; + } } } else { caffe_copy(top[0]->count(), top_diff, bottom_diff); + if (!scale_train_) { + caffe_scal(top[0]->count(), 1. 
/ scale_, bottom_diff); + } } } } diff --git a/src/caffe/layers/dropout_layer.cu b/src/caffe/layers/dropout_layer.cu index 186c10ca489..f9f93536c5b 100644 --- a/src/caffe/layers/dropout_layer.cu +++ b/src/caffe/layers/dropout_layer.cu @@ -25,12 +25,23 @@ void DropoutLayer::Forward_gpu(const vector*>& bottom, static_cast(rand_vec_.mutable_gpu_data()); caffe_gpu_rng_uniform(count, mask); // set thresholds - // NOLINT_NEXT_LINE(whitespace/operators) - DropoutForward<<>>( - count, bottom_data, mask, uint_thres_, scale_, top_data); + if (scale_train_) { + // NOLINT_NEXT_LINE(whitespace/operators) + DropoutForward<<>>( + count, bottom_data, mask, uint_thres_, scale_, top_data); + } else { + // NOLINT_NEXT_LINE(whitespace/operators) + DropoutForward<<>>( + count, bottom_data, mask, uint_thres_, 1.f, top_data); + } CUDA_POST_KERNEL_CHECK; } else { caffe_copy(count, bottom_data, top_data); + if (!scale_train_) { + caffe_gpu_scal(count, 1. / scale_, top_data); + } } } @@ -54,13 +65,23 @@ void DropoutLayer::Backward_gpu(const vector*>& top, const unsigned int* mask = static_cast(rand_vec_.gpu_data()); const int count = bottom[0]->count(); - // NOLINT_NEXT_LINE(whitespace/operators) - DropoutBackward<<>>( - count, top_diff, mask, uint_thres_, scale_, bottom_diff); + if (scale_train_) { + // NOLINT_NEXT_LINE(whitespace/operators) + DropoutBackward<<>>( + count, top_diff, mask, uint_thres_, scale_, bottom_diff); + } else { + // NOLINT_NEXT_LINE(whitespace/operators) + DropoutBackward<<>>( + count, top_diff, mask, uint_thres_, 1.f, bottom_diff); + } CUDA_POST_KERNEL_CHECK; } else { caffe_copy(top[0]->count(), top_diff, bottom_diff); + if (!scale_train_) { + caffe_gpu_scal(top[0]->count(), 1. / scale_, bottom_diff); + } } } } diff --git a/src/caffe/layers/roi_pooling_layer.cpp b/src/caffe/layers/roi_pooling_layer.cpp new file mode 100644 index 00000000000..3394478223e --- /dev/null +++ b/src/caffe/layers/roi_pooling_layer.cpp @@ -0,0 +1,168 @@ +#include +#include +#include + +#include "caffe/layers/roi_pooling_layer.hpp" + +using std::max; +using std::min; +using std::floor; +using std::ceil; + +namespace caffe { + +template +void ROIPoolingLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + ROIPoolingParameter roi_pool_param = this->layer_param_.roi_pooling_param(); + CHECK_GT(roi_pool_param.pooled_h(), 0) + << "pooled_h must be > 0"; + CHECK_GT(roi_pool_param.pooled_w(), 0) + << "pooled_w must be > 0"; + pooled_height_ = roi_pool_param.pooled_h(); + pooled_width_ = roi_pool_param.pooled_w(); + spatial_scale_ = roi_pool_param.spatial_scale(); + LOG(INFO) << "Spatial scale: " << spatial_scale_; +} + +template +void ROIPoolingLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + top[0]->Reshape(bottom[1]->num(), channels_, pooled_height_, + pooled_width_); + max_idx_.Reshape(bottom[1]->num(), channels_, pooled_height_, + pooled_width_); +} + +template +void ROIPoolingLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* bottom_rois = bottom[1]->cpu_data(); + // Number of ROIs + int num_rois = bottom[1]->num(); + int batch_size = bottom[0]->num(); + int top_count = top[0]->count(); + Dtype* top_data = top[0]->mutable_cpu_data(); + caffe_set(top_count, Dtype(-FLT_MAX), top_data); + int* argmax_data = max_idx_.mutable_cpu_data(); + caffe_set(top_count, -1, argmax_data); + + // For each ROI R = 
[batch_index x1 y1 x2 y2]: max pool over R + for (int n = 0; n < num_rois; ++n) { + int roi_batch_ind = bottom_rois[0]; + int roi_start_w = round(bottom_rois[1] * spatial_scale_); + int roi_start_h = round(bottom_rois[2] * spatial_scale_); + int roi_end_w = round(bottom_rois[3] * spatial_scale_); + int roi_end_h = round(bottom_rois[4] * spatial_scale_); + CHECK_GE(roi_batch_ind, 0); + CHECK_LT(roi_batch_ind, batch_size); + + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + const Dtype bin_size_h = static_cast(roi_height) + / static_cast(pooled_height_); + const Dtype bin_size_w = static_cast(roi_width) + / static_cast(pooled_width_); + + const Dtype* batch_data = bottom_data + bottom[0]->offset(roi_batch_ind); + + for (int c = 0; c < channels_; ++c) { + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + // Compute pooling region for this output unit: + // start (included) = floor(ph * roi_height / pooled_height_) + // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_) + int hstart = static_cast(floor(static_cast(ph) + * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) + * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) + * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) + * bin_size_w)); + + hstart = min(max(hstart + roi_start_h, 0), height_); + hend = min(max(hend + roi_start_h, 0), height_); + wstart = min(max(wstart + roi_start_w, 0), width_); + wend = min(max(wend + roi_start_w, 0), width_); + + bool is_empty = (hend <= hstart) || (wend <= wstart); + + const int pool_index = ph * pooled_width_ + pw; + if (is_empty) { + top_data[pool_index] = 0; + argmax_data[pool_index] = -1; + } + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int index = h * width_ + w; + if (batch_data[index] > top_data[pool_index]) { + top_data[pool_index] = batch_data[index]; + argmax_data[pool_index] = index; + } + } + } + } + } + // Increment all data pointers by one channel + batch_data += bottom[0]->offset(0, 1); + top_data += top[0]->offset(0, 1); + argmax_data += max_idx_.offset(0, 1); + } + // Increment ROI data pointer + bottom_rois += bottom[1]->offset(1); + } +} + +template +void ROIPoolingLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[1]) { + LOG(FATAL) << this->type() + << " Layer cannot backpropagate to roi inputs."; + } + if (!propagate_down[0]) { + return; + } + const Dtype* bottom_rois = bottom[1]->cpu_data(); + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + caffe_set(bottom[0]->count(), Dtype(0.), bottom_diff); + const int* argmax_data = max_idx_.cpu_data(); + const int num_rois = top[0]->num(); + + // Accumulate gradient over all ROIs + for (int roi_n = 0; roi_n < num_rois; ++roi_n) { + int roi_batch_ind = bottom_rois[roi_n * 5]; + // Accumulate gradients over each bin in this ROI + for (int c = 0; c < channels_; ++c) { + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + int offset_top = ((roi_n * channels_ + c) * pooled_height_ + ph) + * pooled_width_ + pw; + int argmax_index = argmax_data[offset_top]; + if (argmax_index >= 0) { + int offset_bottom = (roi_batch_ind * channels_ + c) * height_ + * width_ + argmax_index; + bottom_diff[offset_bottom] += top_diff[offset_top]; + } + } + } + } + } +} + + +#ifdef CPU_ONLY 
+STUB_GPU(ROIPoolingLayer); +#endif + +INSTANTIATE_CLASS(ROIPoolingLayer); +REGISTER_LAYER_CLASS(ROIPooling); + +} // namespace caffe diff --git a/src/caffe/layers/roi_pooling_layer.cu b/src/caffe/layers/roi_pooling_layer.cu new file mode 100644 index 00000000000..ea699a2d1d4 --- /dev/null +++ b/src/caffe/layers/roi_pooling_layer.cu @@ -0,0 +1,184 @@ +#include +#include +#include + +#include "caffe/layers/roi_pooling_layer.hpp" + + +using std::max; +using std::min; + +namespace caffe { + +template +__global__ void ROIPoolForward(const int nthreads, const Dtype* bottom_data, + const Dtype spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const Dtype* bottom_rois, Dtype* top_data, int* argmax_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + bottom_rois += n * 5; + int roi_batch_ind = bottom_rois[0]; + int roi_start_w = round(bottom_rois[1] * spatial_scale); + int roi_start_h = round(bottom_rois[2] * spatial_scale); + int roi_end_w = round(bottom_rois[3] * spatial_scale); + int roi_end_h = round(bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + Dtype bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + Dtype bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) + * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) + * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) + * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) + * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + Dtype maxval = is_empty ? 
0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + bottom_data += (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h * width + w; + if (bottom_data[bottom_index] > maxval) { + maxval = bottom_data[bottom_index]; + maxidx = bottom_index; + } + } + } + top_data[index] = maxval; + argmax_data[index] = maxidx; + } +} + +template +void ROIPoolingLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + const Dtype* bottom_rois = bottom[1]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + int* argmax_data = max_idx_.mutable_gpu_data(); + int count = top[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + ROIPoolForward<<>>( + count, bottom_data, spatial_scale_, channels_, height_, width_, + pooled_height_, pooled_width_, bottom_rois, top_data, argmax_data); + CUDA_POST_KERNEL_CHECK; +} + +template +__global__ void ROIPoolBackward(const int nthreads, const Dtype* top_diff, + const int* argmax_data, const int num_rois, const Dtype spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, Dtype* bottom_diff, + const Dtype* bottom_rois) { + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, h, w) coords in bottom data + int w = index % width; + int h = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + + Dtype gradient = 0; + // Accumulate gradient over all ROIs that pooled this element + for (int roi_n = 0; roi_n < num_rois; ++roi_n) { + const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + // Skip if ROI's batch index doesn't match n + if (n != roi_batch_ind) { + continue; + } + + int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Skip if ROI doesn't include (h, w) + const bool in_roi = (w >= roi_start_w && w <= roi_end_w && + h >= roi_start_h && h <= roi_end_h); + if (!in_roi) { + continue; + } + + int offset = (roi_n * channels + c) * pooled_height * pooled_width; + const Dtype* offset_top_diff = top_diff + offset; + const int* offset_argmax_data = argmax_data + offset; + + // Compute feasible set of pooled units that could have pooled + // this bottom unit + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + + Dtype bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + Dtype bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int phstart = floor(static_cast(h - roi_start_h) / bin_size_h); + int phend = ceil(static_cast(h - roi_start_h + 1) / bin_size_h); + int pwstart = floor(static_cast(w - roi_start_w) / bin_size_w); + int pwend = ceil(static_cast(w - roi_start_w + 1) / bin_size_w); + + phstart = min(max(phstart, 0), pooled_height); + phend = min(max(phend, 0), pooled_height); + pwstart = min(max(pwstart, 0), pooled_width); + pwend = min(max(pwend, 0), pooled_width); + + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if (offset_argmax_data[ph * pooled_width + pw] == (h * width + w)) { + 
gradient += offset_top_diff[ph * pooled_width + pw]; + } + } + } + } + bottom_diff[index] = gradient; + } +} + +template +void ROIPoolingLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + const Dtype* bottom_rois = bottom[1]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + const int count = bottom[0]->count(); + caffe_gpu_set(count, Dtype(0.), bottom_diff); + const int* argmax_data = max_idx_.gpu_data(); + // NOLINT_NEXT_LINE(whitespace/operators) + ROIPoolBackward<<>>( + count, top_diff, argmax_data, top[0]->num(), spatial_scale_, channels_, + height_, width_, pooled_height_, pooled_width_, bottom_diff, bottom_rois); + CUDA_POST_KERNEL_CHECK; +} + +INSTANTIATE_LAYER_GPU_FUNCS(ROIPoolingLayer); + +} // namespace caffe diff --git a/src/caffe/layers/smooth_l1_loss_layer.cpp b/src/caffe/layers/smooth_l1_loss_layer.cpp new file mode 100644 index 00000000000..6ddbc6edaf2 --- /dev/null +++ b/src/caffe/layers/smooth_l1_loss_layer.cpp @@ -0,0 +1,65 @@ +#include + +#include "caffe/layers/smooth_l1_loss_layer.hpp" + +namespace caffe { + +template +void SmoothL1LossLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + SmoothL1LossParameter loss_param = this->layer_param_.smooth_l1_loss_param(); + sigma2_ = loss_param.sigma() * loss_param.sigma(); + has_weights_ = (bottom.size() >= 3); + if (has_weights_) { + CHECK_EQ(bottom.size(), 4) << "If weights are used, must specify both " + "inside and outside weights"; + } +} + +template +void SmoothL1LossLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + LossLayer::Reshape(bottom, top); + CHECK_EQ(bottom[0]->channels(), bottom[1]->channels()); + CHECK_EQ(bottom[0]->height(), bottom[1]->height()); + CHECK_EQ(bottom[0]->width(), bottom[1]->width()); + if (has_weights_) { + CHECK_EQ(bottom[0]->channels(), bottom[2]->channels()); + CHECK_EQ(bottom[0]->height(), bottom[2]->height()); + CHECK_EQ(bottom[0]->width(), bottom[2]->width()); + CHECK_EQ(bottom[0]->channels(), bottom[3]->channels()); + CHECK_EQ(bottom[0]->height(), bottom[3]->height()); + CHECK_EQ(bottom[0]->width(), bottom[3]->width()); + } + diff_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + errors_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + // vector of ones used to sum + ones_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + for (int i = 0; i < bottom[0]->count(); ++i) { + ones_.mutable_cpu_data()[i] = Dtype(1); + } +} + +template +void SmoothL1LossLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + NOT_IMPLEMENTED; +} + +template +void SmoothL1LossLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + NOT_IMPLEMENTED; +} + +#ifdef CPU_ONLY +STUB_GPU(SmoothL1LossLayer); +#endif + +INSTANTIATE_CLASS(SmoothL1LossLayer); +REGISTER_LAYER_CLASS(SmoothL1Loss); + +} // namespace caffe diff --git a/src/caffe/layers/smooth_l1_loss_layer.cu b/src/caffe/layers/smooth_l1_loss_layer.cu new file mode 100644 index 00000000000..9ddc9183539 --- /dev/null +++ b/src/caffe/layers/smooth_l1_loss_layer.cu @@ -0,0 +1,116 @@ +#include + +#include "caffe/layers/smooth_l1_loss_layer.hpp" + +namespace caffe { + +template +__global__ void SmoothL1Forward(const int n, const Dtype* in, Dtype* out, + Dtype sigma2) { + // f(x) = 0.5 * 
(sigma * x)^2 if |x| < 1 / sigma / sigma + // |x| - 0.5 / sigma / sigma otherwise + CUDA_KERNEL_LOOP(index, n) { + Dtype val = in[index]; + Dtype abs_val = abs(val); + if (abs_val < 1.0 / sigma2) { + out[index] = 0.5 * val * val * sigma2; + } else { + out[index] = abs_val - 0.5 / sigma2; + } + } +} + +template +void SmoothL1LossLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + int count = bottom[0]->count(); + caffe_gpu_sub( + count, + bottom[0]->gpu_data(), + bottom[1]->gpu_data(), + diff_.mutable_gpu_data()); // d := b0 - b1 + if (has_weights_) { + // apply "inside" weights + caffe_gpu_mul( + count, + bottom[2]->gpu_data(), + diff_.gpu_data(), + diff_.mutable_gpu_data()); // d := w_in * (b0 - b1) + } + // NOLINT_NEXT_LINE(whitespace/operators) + SmoothL1Forward<<>>( + count, diff_.gpu_data(), errors_.mutable_gpu_data(), sigma2_); + CUDA_POST_KERNEL_CHECK; + + if (has_weights_) { + // apply "outside" weights + caffe_gpu_mul( + count, + bottom[3]->gpu_data(), + errors_.gpu_data(), + errors_.mutable_gpu_data()); // d := w_out * SmoothL1(w_in * (b0 - b1)) + } + + Dtype loss; + caffe_gpu_dot(count, ones_.gpu_data(), errors_.gpu_data(), &loss); + top[0]->mutable_cpu_data()[0] = loss / bottom[0]->num(); +} + +template +__global__ void SmoothL1Backward(const int n, const Dtype* in, Dtype* out, + Dtype sigma2) { + // f'(x) = sigma * sigma * x if |x| < 1 / sigma / sigma + // = sign(x) otherwise + CUDA_KERNEL_LOOP(index, n) { + Dtype val = in[index]; + Dtype abs_val = abs(val); + if (abs_val < 1.0 / sigma2) { + out[index] = sigma2 * val; + } else { + out[index] = (Dtype(0) < val) - (val < Dtype(0)); + } + } +} + +template +void SmoothL1LossLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + // after forwards, diff_ holds w_in * (b0 - b1) + int count = diff_.count(); + // NOLINT_NEXT_LINE(whitespace/operators) + SmoothL1Backward<<>>( + count, diff_.gpu_data(), diff_.mutable_gpu_data(), sigma2_); + CUDA_POST_KERNEL_CHECK; + for (int i = 0; i < 2; ++i) { + if (propagate_down[i]) { + const Dtype sign = (i == 0) ? 1 : -1; + const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->num(); + caffe_gpu_axpby( + count, // count + alpha, // alpha + diff_.gpu_data(), // x + Dtype(0), // beta + bottom[i]->mutable_gpu_diff()); // y + if (has_weights_) { + // Scale by "inside" weight + caffe_gpu_mul( + count, + bottom[2]->gpu_data(), + bottom[i]->gpu_diff(), + bottom[i]->mutable_gpu_diff()); + // Scale by "outside" weight + caffe_gpu_mul( + count, + bottom[3]->gpu_data(), + bottom[i]->gpu_diff(), + bottom[i]->mutable_gpu_diff()); + } + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(SmoothL1LossLayer); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index c96966b589d..3b53aa1fc63 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -308,7 +308,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. 
// -// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param) +// LayerParameter next available layer-specific ID: 149 (last added: smooth_l1_loss_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -396,8 +396,10 @@ message LayerParameter { optional ReductionParameter reduction_param = 136; optional ReLUParameter relu_param = 123; optional ReshapeParameter reshape_param = 133; + optional ROIPoolingParameter roi_pooling_param = 147; optional ScaleParameter scale_param = 142; optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 148; optional SoftmaxParameter softmax_param = 125; optional SPPParameter spp_param = 132; optional SliceParameter slice_param = 126; @@ -674,6 +676,7 @@ message DataParameter { message DropoutParameter { optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase } // DummyDataLayer fills any number of arbitrarily shaped blobs with random @@ -1070,6 +1073,17 @@ message ReshapeParameter { optional int32 num_axes = 3 [default = -1]; } +// Message that stores parameters used by ROIPoolingLayer +message ROIPoolingParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; +} + message ScaleParameter { // The first axis of bottom[0] (the first input Blob) along which to apply // bottom[1] (the second input Blob). 
May be negative to index from the end @@ -1127,6 +1141,13 @@ message SliceParameter { optional uint32 slice_dim = 1 [default = 1]; } +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + // Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer message SoftmaxParameter { enum Engine { diff --git a/src/caffe/test/test_roi_pooling_layer.cpp b/src/caffe/test/test_roi_pooling_layer.cpp new file mode 100644 index 00000000000..f678bd1d10a --- /dev/null +++ b/src/caffe/test/test_roi_pooling_layer.cpp @@ -0,0 +1,161 @@ +#include +#include +#include +#include +#include +#include + +#include "boost/scoped_ptr.hpp" +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/roi_pooling_layer.hpp" +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +using boost::scoped_ptr; + +namespace caffe { + +template +class ROIPoolingLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + ROIPoolingLayerTest() + : blob_bottom_data_(new Blob(2, 2, 6, 8)), + blob_bottom_rois_(new Blob(4, 5, 1, 1)), + blob_top_data_(new Blob()), + blob_bottom_data_2_(new Blob(2, 3, 12, 20)), + blob_bottom_rois_2_(new Blob(1, 5, 1, 1)), + blob_top_data_2_(new Blob()) { + // fill the values + FillerParameter filler_param; + filler_param.set_std(10); + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_data_); + // for (int i = 0; i < blob_bottom_data_->count(); ++i) { + // blob_bottom_data_->mutable_cpu_data()[i] = i; + // } + blob_bottom_vec_.push_back(blob_bottom_data_); + int i = 0; + blob_bottom_rois_->mutable_cpu_data()[0 + 5*i] = 0; + blob_bottom_rois_->mutable_cpu_data()[1 + 5*i] = 0; // x1 < 8 + blob_bottom_rois_->mutable_cpu_data()[2 + 5*i] = 0; // y1 < 6 + blob_bottom_rois_->mutable_cpu_data()[3 + 5*i] = 7; // x2 < 8 + blob_bottom_rois_->mutable_cpu_data()[4 + 5*i] = 5; // y2 < 6 + i = 1; + blob_bottom_rois_->mutable_cpu_data()[0 + 5*i] = 1; + blob_bottom_rois_->mutable_cpu_data()[1 + 5*i] = 6; // x1 < 8 + blob_bottom_rois_->mutable_cpu_data()[2 + 5*i] = 2; // y1 < 6 + blob_bottom_rois_->mutable_cpu_data()[3 + 5*i] = 7; // x2 < 8 + blob_bottom_rois_->mutable_cpu_data()[4 + 5*i] = 5; // y2 < 6 + i = 2; + blob_bottom_rois_->mutable_cpu_data()[0 + 5*i] = 1; + blob_bottom_rois_->mutable_cpu_data()[1 + 5*i] = 3; // x1 < 8 + blob_bottom_rois_->mutable_cpu_data()[2 + 5*i] = 1; // y1 < 6 + blob_bottom_rois_->mutable_cpu_data()[3 + 5*i] = 6; // x2 < 8 + blob_bottom_rois_->mutable_cpu_data()[4 + 5*i] = 4; // y2 < 6 + i = 3; + blob_bottom_rois_->mutable_cpu_data()[0 + 5*i] = 0; + blob_bottom_rois_->mutable_cpu_data()[1 + 5*i] = 3; // x1 < 8 + blob_bottom_rois_->mutable_cpu_data()[2 + 5*i] = 3; // y1 < 6 + blob_bottom_rois_->mutable_cpu_data()[3 + 5*i] = 3; // x2 < 8 + blob_bottom_rois_->mutable_cpu_data()[4 + 5*i] = 3; // y2 < 6 + + blob_bottom_vec_.push_back(blob_bottom_rois_); + blob_top_vec_.push_back(blob_top_data_); + + filler.Fill(this->blob_bottom_data_2_); + blob_bottom_vec_2_.push_back(blob_bottom_data_2_); + + // Pool over the entire bottom of feature map 1 + blob_bottom_rois_2_->mutable_cpu_data()[0] = 1; + blob_bottom_rois_2_->mutable_cpu_data()[1] = 0; + blob_bottom_rois_2_->mutable_cpu_data()[2] = 0; + blob_bottom_rois_2_->mutable_cpu_data()[3] = 19; + 
blob_bottom_rois_2_->mutable_cpu_data()[4] = 11; + + blob_bottom_vec_2_.push_back(blob_bottom_rois_2_); + blob_top_vec_2_.push_back(blob_top_data_2_); + } + virtual ~ROIPoolingLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_rois_; + delete blob_top_data_; + delete blob_bottom_data_2_; + delete blob_bottom_rois_2_; + delete blob_top_data_2_; + } + Blob* const blob_bottom_data_; + Blob* const blob_bottom_rois_; + Blob* const blob_top_data_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + + Blob* const blob_bottom_data_2_; + Blob* const blob_bottom_rois_2_; + Blob* const blob_top_data_2_; + vector*> blob_bottom_vec_2_; + vector*> blob_top_vec_2_; +}; + +TYPED_TEST_CASE(ROIPoolingLayerTest, TestDtypesAndDevices); + +TYPED_TEST(ROIPoolingLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ROIPoolingParameter* roi_pooling_param = + layer_param.mutable_roi_pooling_param(); + + // 12 x 20 pooling with bin_size_h == 1 && bin_size_w == 1 + roi_pooling_param->set_pooled_h(12); + roi_pooling_param->set_pooled_w(20); + ROIPoolingLayer layer_2(layer_param); + layer_2.SetUp(this->blob_bottom_vec_2_, this->blob_top_vec_2_); + layer_2.Forward(this->blob_bottom_vec_2_, this->blob_top_vec_2_); + for (int i = 0; i < this->blob_top_data_2_->count(); ++i) { + EXPECT_EQ(this->blob_top_data_2_->cpu_data()[i], + this->blob_bottom_data_2_->cpu_data()[i+3*12*20]); + } + + // 6 x 10 pooling with bin_size_h == 2 && bin_size_w == 2 + roi_pooling_param->set_pooled_h(6); + roi_pooling_param->set_pooled_w(10); + ROIPoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_2_, this->blob_top_vec_2_); + layer.Forward(this->blob_bottom_vec_2_, this->blob_top_vec_2_); + int n = 1; + for (int c = 0; c < 3; ++c) { + for (int ph = 0; ph < 6; ++ph) { + for (int pw = 0; pw < 10; ++pw) { + Dtype maxval = -FLT_MAX; + for (int h = 2 * ph; h < 2 * (ph + 1); ++h) { + for (int w = 2 * pw; w < 2 * (pw + 1); ++w) { + maxval = std::max(maxval, this->blob_bottom_data_2_->cpu_data()[ + ((n * 3 + c) * 12 + h) * 20 + w]); + } + } + EXPECT_EQ(this->blob_top_data_2_->cpu_data()[(c * 6 + ph) * 10 + pw], + maxval); + } + } + } +} + +TYPED_TEST(ROIPoolingLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ROIPoolingParameter* roi_pooling_param = + layer_param.mutable_roi_pooling_param(); + roi_pooling_param->set_pooled_h(3); + roi_pooling_param->set_pooled_w(4); + ROIPoolingLayer layer(layer_param); + GradientChecker checker(1e-4, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +} // namespace caffe diff --git a/src/caffe/test/test_smooth_l1_loss_layer.cpp b/src/caffe/test/test_smooth_l1_loss_layer.cpp new file mode 100644 index 00000000000..7f60a8d3a35 --- /dev/null +++ b/src/caffe/test/test_smooth_l1_loss_layer.cpp @@ -0,0 +1,85 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/smooth_l1_loss_layer.hpp" +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +#ifndef CPU_ONLY +template +class SmoothL1LossLayerTest : public GPUDeviceTest { + protected: + SmoothL1LossLayerTest() + : blob_bottom_data_(new Blob(10, 5, 1, 1)), + blob_bottom_label_(new Blob(10, 5, 1, 1)), + blob_bottom_inside_weights_(new Blob(10, 5, 1, 1)), + blob_bottom_outside_weights_(new Blob(10, 5, 1, 
1)), + blob_top_loss_(new Blob()) { + // fill the values + FillerParameter const_filler_param; + const_filler_param.set_value(-1.); + ConstantFiller const_filler(const_filler_param); + FillerParameter filler_param; + GaussianFiller filler(filler_param); + + filler.Fill(this->blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_data_); + + filler.Fill(this->blob_bottom_label_); + blob_bottom_vec_.push_back(blob_bottom_label_); + + filler.Fill(this->blob_bottom_inside_weights_); + blob_bottom_vec_.push_back(blob_bottom_inside_weights_); + + filler.Fill(this->blob_bottom_outside_weights_); + blob_bottom_vec_.push_back(blob_bottom_outside_weights_); + + blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~SmoothL1LossLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_label_; + delete blob_bottom_inside_weights_; + delete blob_bottom_outside_weights_; + delete blob_top_loss_; + } + + Blob* const blob_bottom_data_; + Blob* const blob_bottom_label_; + Blob* const blob_bottom_inside_weights_; + Blob* const blob_bottom_outside_weights_; + Blob* const blob_top_loss_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(SmoothL1LossLayerTest, TestDtypes); + +TYPED_TEST(SmoothL1LossLayerTest, TestGradient) { + LayerParameter layer_param; + SmoothL1LossParameter* loss_param = + layer_param.mutable_smooth_l1_loss_param(); + loss_param->set_sigma(2.4); + + const TypeParam kLossWeight = 3.7; + layer_param.add_loss_weight(kLossWeight); + SmoothL1LossLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + GradientChecker checker(1e-2, 1e-2, 1701); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 1); +} +#endif + +} // namespace caffe
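For reference, below is a minimal, hypothetical prototxt sketch of how the parameters introduced by this patch could be wired into a network definition. The blob and layer names ("conv5", "rois", "fc6", "bbox_pred", ...) and all numeric values are illustrative assumptions, not part of the patch; only the layer types and the parameter fields (roi_pooling_param, dropout_param.scale_train, smooth_l1_loss_param) come from the changes above.

# Hypothetical usage sketch -- names and values are illustrative only.
layer {
  name: "roi_pool5"
  type: "ROIPooling"
  bottom: "conv5"      # [N x C x H x W] feature maps
  bottom: "rois"       # [R x 5] tuples: (batch_index, x1, y1, x2, y2)
  top: "pool5"
  roi_pooling_param {
    pooled_h: 7
    pooled_w: 7
    spatial_scale: 0.0625   # e.g. 1/16 for a feature map downsampled 16x
  }
}
layer {
  name: "drop6"
  type: "Dropout"
  bottom: "fc6"
  top: "fc6"
  dropout_param {
    dropout_ratio: 0.5
    # new option: skip the 1/(1 - dropout_ratio) scaling at train time and
    # scale by (1 - dropout_ratio) at test time instead
    scale_train: false
  }
}
layer {
  name: "loss_bbox"
  type: "SmoothL1Loss"
  bottom: "bbox_pred"
  bottom: "bbox_targets"
  bottom: "bbox_inside_weights"    # optional; if weights are used, both
  bottom: "bbox_outside_weights"   # inside and outside weights are required
  top: "loss_bbox"
  loss_weight: 1
  smooth_l1_loss_param { sigma: 1 }
}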