From 87aaf47256cc8010599b5a3f11410d2e807037c8 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Sun, 21 Dec 2014 19:20:36 -0800 Subject: [PATCH 1/9] add BaseConvolutionLayer This provides a common place for code used by ConvolutionLayer and DeconvolutionLayer, simplifying the implementations of both. --- include/caffe/vision_layers.hpp | 87 ++++++++ src/caffe/layers/base_conv_layer.cpp | 289 +++++++++++++++++++++++++++ 2 files changed, 376 insertions(+) create mode 100644 src/caffe/layers/base_conv_layer.cpp diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index c803cd72449..e1bb7abf07c 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -16,6 +16,93 @@ namespace caffe { +template +class BaseConvolutionLayer : public Layer { + public: + explicit BaseConvolutionLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline int MinBottomBlobs() const { return 1; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline bool EqualNumBottomTopBlobs() const { return true; } + + protected: + // Helper functions that abstract away the column buffer and gemm arguments. + // The last argument in forward_cpu_gemm is so that we can skip the im2col if + // we just called weight_cpu_gemm with the same input. + void forward_cpu_gemm(const Dtype* input, const Dtype* weights, + Dtype* output, bool skip_im2col = false); + void forward_cpu_bias(Dtype* output, const Dtype* bias); + void backward_cpu_gemm(const Dtype* input, const Dtype* weights, + Dtype* output); + void weight_cpu_gemm(const Dtype* input, const Dtype* output, Dtype* + weights); + void backward_cpu_bias(Dtype* bias, const Dtype* input); + + void forward_gpu_gemm(const Dtype* col_input, const Dtype* weights, + Dtype* output, bool skip_im2col = false); + void forward_gpu_bias(Dtype* output, const Dtype* bias); + void backward_gpu_gemm(const Dtype* input, const Dtype* weights, + Dtype* col_output); + void weight_gpu_gemm(const Dtype* col_input, const Dtype* output, Dtype* + weights); + void backward_gpu_bias(Dtype* bias, const Dtype* input); + + // reverse_dimensions should return true iff we are implementing deconv, so + // that conv helpers know which dimensions are which. + virtual bool reverse_dimensions() = 0; + // Compute height_out_ and width_out_ from other parameters. + virtual void compute_output_shape() = 0; + + int kernel_h_, kernel_w_; + int stride_h_, stride_w_; + int num_; + int channels_; + int pad_h_, pad_w_; + int height_, width_; + int group_; + int num_output_; + int height_out_, width_out_; + bool bias_term_; + bool is_1x1_; + + private: + // wrap im2col/col2im so we don't have to remember the (long) argument lists + inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) { + im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + } + inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) { + col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + } + inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) { + im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + } + inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) { + col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + } + + int conv_out_channels_; + int conv_in_channels_; + int conv_out_spatial_dim_; + int conv_in_height_; + int conv_in_width_; + int kernel_dim_; + int weight_offset_; + int col_offset_; + int output_offset_; + + Blob col_buffer_; + Blob bias_multiplier_; +}; + /** * @brief Convolves the input image with a bank of learned filters, * and (optionally) adds biases. diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp new file mode 100644 index 00000000000..a44867a0cf3 --- /dev/null +++ b/src/caffe/layers/base_conv_layer.cpp @@ -0,0 +1,289 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + // Configure the kernel size, padding, stride, and inputs. + ConvolutionParameter conv_param = this->layer_param_.convolution_param(); + CHECK(!conv_param.has_kernel_size() != + !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) + << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; + CHECK(conv_param.has_kernel_size() || + (conv_param.has_kernel_h() && conv_param.has_kernel_w())) + << "For non-square filters both kernel_h and kernel_w are required."; + CHECK((!conv_param.has_pad() && conv_param.has_pad_h() + && conv_param.has_pad_w()) + || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) + << "pad is pad OR pad_h and pad_w are required."; + CHECK((!conv_param.has_stride() && conv_param.has_stride_h() + && conv_param.has_stride_w()) + || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) + << "Stride is stride OR stride_h and stride_w are required."; + if (conv_param.has_kernel_size()) { + kernel_h_ = kernel_w_ = conv_param.kernel_size(); + } else { + kernel_h_ = conv_param.kernel_h(); + kernel_w_ = conv_param.kernel_w(); + } + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + if (!conv_param.has_pad_h()) { + pad_h_ = pad_w_ = conv_param.pad(); + } else { + pad_h_ = conv_param.pad_h(); + pad_w_ = conv_param.pad_w(); + } + if (!conv_param.has_stride_h()) { + stride_h_ = stride_w_ = conv_param.stride(); + } else { + stride_h_ = conv_param.stride_h(); + stride_w_ = conv_param.stride_w(); + } + // Special case: im2col is the identity for 1x1 convolution with stride 1 + // and no padding, so flag for skipping the buffer and transformation. + is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1 + && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0; + // Configure output channels and groups. + channels_ = bottom[0]->channels(); + num_output_ = this->layer_param_.convolution_param().num_output(); + CHECK_GT(num_output_, 0); + group_ = this->layer_param_.convolution_param().group(); + CHECK_EQ(channels_ % group_, 0); + CHECK_EQ(num_output_ % group_, 0) + << "Number of output should be multiples of group."; + if (reverse_dimensions()) { + conv_out_channels_ = channels_; + conv_in_channels_ = num_output_; + } else { + conv_out_channels_ = num_output_; + conv_in_channels_ = channels_; + } + // Handle the parameters: weights and biases. + // - blobs_[0] holds the filter weights + // - blobs_[1] holds the biases (optional) + bias_term_ = this->layer_param_.convolution_param().bias_term(); + if (this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else { + if (bias_term_) { + this->blobs_.resize(2); + } else { + this->blobs_.resize(1); + } + // Initialize and fill the weights: + // output channels x input channels per-group x kernel height x kernel width + this->blobs_[0].reset(new Blob( + conv_out_channels_, conv_in_channels_ / group_, kernel_h_, kernel_w_)); + shared_ptr > weight_filler(GetFiller( + this->layer_param_.convolution_param().weight_filler())); + weight_filler->Fill(this->blobs_[0].get()); + // If necessary, initialize and fill the biases: + // 1 x 1 x 1 x output channels + if (bias_term_) { + this->blobs_[1].reset(new Blob(1, 1, 1, num_output_)); + shared_ptr > bias_filler(GetFiller( + this->layer_param_.convolution_param().bias_filler())); + bias_filler->Fill(this->blobs_[1].get()); + } + } + // Propagate gradients to the parameters (as directed by backward pass). + this->param_propagate_down_.resize(this->blobs_.size(), true); +} + +template +void BaseConvolutionLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + num_ = bottom[0]->num(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with" + " convolution kernel."; + // TODO: generalize to handle inputs of different shapes. + for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { + CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; + CHECK_EQ(channels_, bottom[bottom_id]->channels()) + << "Inputs must have same channels."; + CHECK_EQ(height_, bottom[bottom_id]->height()) + << "Inputs must have same height."; + CHECK_EQ(width_, bottom[bottom_id]->width()) + << "Inputs must have same width."; + } + // Shape the tops. + compute_output_shape(); + for (int top_id = 0; top_id < top.size(); ++top_id) { + top[top_id]->Reshape(num_, num_output_, height_out_, width_out_); + } + if (reverse_dimensions()) { + conv_in_height_ = height_out_; + conv_in_width_ = width_out_; + conv_out_spatial_dim_ = height_ * width_; + } else { + conv_in_height_ = height_; + conv_in_width_ = width_; + conv_out_spatial_dim_ = height_out_ * width_out_; + } + kernel_dim_ = conv_in_channels_ * kernel_h_ * kernel_w_; + weight_offset_ = conv_out_channels_ * kernel_dim_ / group_ / group_; + col_offset_ = kernel_dim_ * conv_out_spatial_dim_ / group_; + output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_; + // The im2col result buffer will only hold one image at a time to avoid + // overly large memory usage. In the special case of 1x1 convolution + // it goes lazily unused to save memory. + if (reverse_dimensions()) { + col_buffer_.Reshape(1, kernel_dim_, height_, width_); + } else { + col_buffer_.Reshape(1, kernel_dim_, height_out_, width_out_); + } + // Set up the all ones "bias multiplier" for adding biases by BLAS + if (bias_term_) { + bias_multiplier_.Reshape(1, 1, 1, height_out_ * width_out_); + caffe_set(bias_multiplier_.count(), Dtype(1), + bias_multiplier_.mutable_cpu_data()); + } +} + +template +void BaseConvolutionLayer::forward_cpu_gemm(const Dtype* input, + const Dtype* weights, Dtype* output, bool skip_im2col) { + const Dtype* col_buff = input; + if (!is_1x1_) { + if (!skip_im2col) { + conv_im2col_cpu(input, col_buffer_.mutable_cpu_data()); + } + col_buff = col_buffer_.cpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_ / + group_, conv_out_spatial_dim_, kernel_dim_ / group_, + (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g, + (Dtype)0., output + output_offset_ * g); + } +} + +template +void BaseConvolutionLayer::forward_cpu_bias(Dtype* output, + const Dtype* bias) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, + height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), + (Dtype)1., output); +} + +template +void BaseConvolutionLayer::backward_cpu_gemm(const Dtype* output, + const Dtype* weights, Dtype* input) { + Dtype* col_buff = col_buffer_.mutable_cpu_data(); + if (is_1x1_) { + col_buff = input; + } + for (int g = 0; g < group_; ++g) { + caffe_cpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + conv_out_spatial_dim_, conv_out_channels_ / group_, + (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, + (Dtype)0., col_buff + col_offset_ * g); + } + if (!is_1x1_) { + conv_col2im_cpu(col_buff, input); + } +} + +template +void BaseConvolutionLayer::weight_cpu_gemm(const Dtype* input, + const Dtype* output, Dtype* weights) { + const Dtype* col_buff = input; + if (!is_1x1_) { + conv_im2col_cpu(input, col_buffer_.mutable_cpu_data()); + col_buff = col_buffer_.cpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_cpu_gemm(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, + kernel_dim_ / group_, conv_out_spatial_dim_, + (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, + (Dtype)1., weights + weight_offset_ * g); + } +} + +template +void BaseConvolutionLayer::backward_cpu_bias(Dtype* bias, + const Dtype* input) { + caffe_cpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + input, bias_multiplier_.cpu_data(), 1., bias); +} + +template +void BaseConvolutionLayer::forward_gpu_gemm(const Dtype* input, + const Dtype* weights, Dtype* output, bool skip_im2col) { + const Dtype* col_buff = input; + if (!is_1x1_) { + if (!skip_im2col) { + conv_im2col_gpu(input, col_buffer_.mutable_gpu_data()); + } + col_buff = col_buffer_.gpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_ / + group_, conv_out_spatial_dim_, kernel_dim_ / group_, + (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g, + (Dtype)0., output + output_offset_ * g); + } +} + +template +void BaseConvolutionLayer::forward_gpu_bias(Dtype* output, + const Dtype* bias) { + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, + height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), + (Dtype)1., output); +} + +template +void BaseConvolutionLayer::backward_gpu_gemm(const Dtype* output, + const Dtype* weights, Dtype* input) { + Dtype* col_buff = col_buffer_.mutable_gpu_data(); + if (is_1x1_) { + col_buff = input; + } + for (int g = 0; g < group_; ++g) { + caffe_gpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + conv_out_spatial_dim_, conv_out_channels_ / group_, + (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, + (Dtype)0., col_buff + col_offset_ * g); + } + if (!is_1x1_) { + conv_col2im_gpu(col_buff, input); + } +} + +template +void BaseConvolutionLayer::weight_gpu_gemm(const Dtype* input, + const Dtype* output, Dtype* weights) { + const Dtype* col_buff = input; + if (!is_1x1_) { + conv_im2col_gpu(input, col_buffer_.mutable_gpu_data()); + col_buff = col_buffer_.gpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_gpu_gemm(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, + kernel_dim_ / group_, conv_out_spatial_dim_, + (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, + (Dtype)1., weights + weight_offset_ * g); + } +} + +template +void BaseConvolutionLayer::backward_gpu_bias(Dtype* bias, + const Dtype* input) { + caffe_gpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + input, bias_multiplier_.gpu_data(), 1., bias); +} + +INSTANTIATE_CLASS(BaseConvolutionLayer); + +} // namespace caffe From 33a940790f2bf02c727b82050013609c31ebdda1 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Sun, 21 Dec 2014 23:33:35 -0800 Subject: [PATCH 2/9] add CPU_ONLY ifdef guards to BaseConvolutionLayer --- include/caffe/vision_layers.hpp | 4 ++++ src/caffe/layers/base_conv_layer.cpp | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index e1bb7abf07c..4cf7e8b0428 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -43,6 +43,7 @@ class BaseConvolutionLayer : public Layer { weights); void backward_cpu_bias(Dtype* bias, const Dtype* input); +#ifndef CPU_ONLY void forward_gpu_gemm(const Dtype* col_input, const Dtype* weights, Dtype* output, bool skip_im2col = false); void forward_gpu_bias(Dtype* output, const Dtype* bias); @@ -51,6 +52,7 @@ class BaseConvolutionLayer : public Layer { void weight_gpu_gemm(const Dtype* col_input, const Dtype* output, Dtype* weights); void backward_gpu_bias(Dtype* bias, const Dtype* input); +#endif // reverse_dimensions should return true iff we are implementing deconv, so // that conv helpers know which dimensions are which. @@ -80,6 +82,7 @@ class BaseConvolutionLayer : public Layer { col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); } +#ifndef CPU_ONLY inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) { im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); @@ -88,6 +91,7 @@ class BaseConvolutionLayer : public Layer { col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); } +#endif int conv_out_channels_; int conv_in_channels_; diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index a44867a0cf3..dccd5170c11 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -217,6 +217,8 @@ void BaseConvolutionLayer::backward_cpu_bias(Dtype* bias, input, bias_multiplier_.cpu_data(), 1., bias); } +#ifndef CPU_ONLY + template void BaseConvolutionLayer::forward_gpu_gemm(const Dtype* input, const Dtype* weights, Dtype* output, bool skip_im2col) { @@ -284,6 +286,8 @@ void BaseConvolutionLayer::backward_gpu_bias(Dtype* bias, input, bias_multiplier_.gpu_data(), 1., bias); } +#endif // !CPU_ONLY + INSTANTIATE_CLASS(BaseConvolutionLayer); } // namespace caffe From 6ffe3ef3472c05df5d27ea471dfa4a6ec0d7649d Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Sun, 21 Dec 2014 19:42:29 -0800 Subject: [PATCH 3/9] rewrite ConvolutionLayer to use BaseConvolutionLayer helpers --- include/caffe/vision_layers.hpp | 36 +---- src/caffe/layers/conv_layer.cpp | 243 ++++---------------------------- src/caffe/layers/conv_layer.cu | 117 +++------------ 3 files changed, 56 insertions(+), 340 deletions(-) diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 4cf7e8b0428..4d93e6c10fc 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -124,7 +124,7 @@ class BaseConvolutionLayer : public Layer { * the output channel N' columns of the output matrix. */ template -class ConvolutionLayer : public Layer { +class ConvolutionLayer : public BaseConvolutionLayer { public: /** * @param param provides ConvolutionParameter convolution_param, @@ -155,18 +155,10 @@ class ConvolutionLayer : public Layer { * kernels + stream parallelism) engines. */ explicit ConvolutionLayer(const LayerParameter& param) - : Layer(param) {} - virtual void LayerSetUp(const vector*>& bottom, - const vector*>& top); - virtual void Reshape(const vector*>& bottom, - const vector*>& top); - + : BaseConvolutionLayer(param) {} virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_CONVOLUTION; } - virtual inline int MinBottomBlobs() const { return 1; } - virtual inline int MinTopBlobs() const { return 1; } - virtual inline bool EqualNumBottomTopBlobs() const { return true; } protected: virtual void Forward_cpu(const vector*>& bottom, @@ -177,30 +169,10 @@ class ConvolutionLayer : public Layer { const vector& propagate_down, const vector*>& bottom); virtual void Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); + virtual inline bool reverse_dimensions() { return false; } + virtual void compute_output_shape(); - int kernel_h_, kernel_w_; - int stride_h_, stride_w_; - int num_; - int channels_; - int pad_h_, pad_w_; - int height_, width_; - int group_; - int num_output_; - int height_out_, width_out_; - bool bias_term_; - bool is_1x1_; - /// M_ is the channel dimension of the output for a single group, which is the - /// leading dimension of the filter matrix. - int M_; - /// K_ is the dimension of an unrolled input for a single group, which is the - /// leading dimension of the data matrix. - int K_; - /// N_ is the spatial dimension of the output, the H x W, which are the last - /// dimensions of the data and filter matrices. - int N_; - Blob col_buffer_; - Blob bias_multiplier_; }; #ifdef USE_CUDNN diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index 0a032025bfb..9fd2fc6a15f 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -9,166 +9,26 @@ namespace caffe { template -void ConvolutionLayer::LayerSetUp(const vector*>& bottom, - const vector*>& top) { - // Configure the kernel size, padding, stride, and inputs. - ConvolutionParameter conv_param = this->layer_param_.convolution_param(); - CHECK(!conv_param.has_kernel_size() != - !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; - CHECK(conv_param.has_kernel_size() || - (conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "For non-square filters both kernel_h and kernel_w are required."; - CHECK((!conv_param.has_pad() && conv_param.has_pad_h() - && conv_param.has_pad_w()) - || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) - << "pad is pad OR pad_h and pad_w are required."; - CHECK((!conv_param.has_stride() && conv_param.has_stride_h() - && conv_param.has_stride_w()) - || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) - << "Stride is stride OR stride_h and stride_w are required."; - if (conv_param.has_kernel_size()) { - kernel_h_ = kernel_w_ = conv_param.kernel_size(); - } else { - kernel_h_ = conv_param.kernel_h(); - kernel_w_ = conv_param.kernel_w(); - } - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - if (!conv_param.has_pad_h()) { - pad_h_ = pad_w_ = conv_param.pad(); - } else { - pad_h_ = conv_param.pad_h(); - pad_w_ = conv_param.pad_w(); - } - if (!conv_param.has_stride_h()) { - stride_h_ = stride_w_ = conv_param.stride(); - } else { - stride_h_ = conv_param.stride_h(); - stride_w_ = conv_param.stride_w(); - } - // Special case: im2col is the identity for 1x1 convolution with stride 1 - // and no padding, so flag for skipping the buffer and transformation. - is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1 - && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0; - // Configure output channels and groups. - channels_ = bottom[0]->channels(); - num_output_ = this->layer_param_.convolution_param().num_output(); - CHECK_GT(num_output_, 0); - group_ = this->layer_param_.convolution_param().group(); - CHECK_EQ(channels_ % group_, 0); - CHECK_EQ(num_output_ % group_, 0) - << "Number of output should be multiples of group."; - // Handle the parameters: weights and biases. - // - blobs_[0] holds the filter weights - // - blobs_[1] holds the biases (optional) - bias_term_ = this->layer_param_.convolution_param().bias_term(); - if (this->blobs_.size() > 0) { - LOG(INFO) << "Skipping parameter initialization"; - } else { - if (bias_term_) { - this->blobs_.resize(2); - } else { - this->blobs_.resize(1); - } - // Initialize and fill the weights: - // output channels x input channels per-group x kernel height x kernel width - this->blobs_[0].reset(new Blob( - num_output_, channels_ / group_, kernel_h_, kernel_w_)); - shared_ptr > weight_filler(GetFiller( - this->layer_param_.convolution_param().weight_filler())); - weight_filler->Fill(this->blobs_[0].get()); - // If necessary, initialize and fill the biases: - // 1 x 1 x 1 x output channels - if (bias_term_) { - this->blobs_[1].reset(new Blob(1, 1, 1, num_output_)); - shared_ptr > bias_filler(GetFiller( - this->layer_param_.convolution_param().bias_filler())); - bias_filler->Fill(this->blobs_[1].get()); - } - } - // Propagate gradients to the parameters (as directed by backward pass). - this->param_propagate_down_.resize(this->blobs_.size(), true); -} - -template -void ConvolutionLayer::Reshape(const vector*>& bottom, - const vector*>& top) { - num_ = bottom[0]->num(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with" - " convolution kernel."; - // TODO: generalize to handle inputs of different shapes. - for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { - CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; - CHECK_EQ(channels_, bottom[bottom_id]->channels()) - << "Inputs must have same channels."; - CHECK_EQ(height_, bottom[bottom_id]->height()) - << "Inputs must have same height."; - CHECK_EQ(width_, bottom[bottom_id]->width()) - << "Inputs must have same width."; - } - // Shape the tops. - height_out_ = - (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1; - width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1; - for (int top_id = 0; top_id < top.size(); ++top_id) { - top[top_id]->Reshape(num_, num_output_, height_out_, width_out_); - } - // Prepare the matrix multiplication computation. - // Each input will be convolved as a single GEMM. - M_ = num_output_ / group_; - K_ = channels_ * kernel_h_ * kernel_w_ / group_; - N_ = height_out_ * width_out_; - // The im2col result buffer will only hold one image at a time to avoid - // overly large memory usage. In the special case of 1x1 convolution - // it goes lazily unused to save memory. - col_buffer_.Reshape( - 1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_); - // Set up the all ones "bias multiplier" for adding biases by BLAS - if (bias_term_) { - bias_multiplier_.Reshape(1, 1, 1, N_); - caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data()); - } +void ConvolutionLayer::compute_output_shape() { + this->height_out_ = (this->height_ + 2 * this->pad_h_ - this->kernel_h_) + / this->stride_h_ + 1; + this->width_out_ = (this->width_ + 2 * this->pad_w_ - this->kernel_w_) + / this->stride_w_ + 1; } template void ConvolutionLayer::Forward_cpu(const vector*>& bottom, const vector*>& top) { + const Dtype* weight = this->blobs_[0]->cpu_data(); for (int i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = top[i]->mutable_cpu_data(); - Dtype* col_buff = NULL; - if (!is_1x1_) { - col_buff = col_buffer_.mutable_cpu_data(); - } - const Dtype* weight = this->blobs_[0]->cpu_data(); - int weight_offset = M_ * K_; // number of filter parameters in a group - int col_offset = K_ * N_; // number of values in an input region / column - int top_offset = M_ * N_; // number of values in an output region / column - for (int n = 0; n < num_; ++n) { - // im2col transformation: unroll input regions for filtering - // into column matrix for multplication. - if (!is_1x1_) { - im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, - col_buff); - } else { // special case for 1x1 convolution - col_buff = bottom[i]->mutable_cpu_data() + bottom[i]->offset(n); - } - // Take inner products for groups. - for (int g = 0; g < group_; ++g) { - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, K_, - (Dtype)1., weight + weight_offset * g, col_buff + col_offset * g, - (Dtype)0., top_data + top[i]->offset(n) + top_offset * g); - } - // Add bias. - if (bias_term_) { - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, - N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(), - bias_multiplier_.cpu_data(), - (Dtype)1., top_data + top[i]->offset(n)); + for (int n = 0; n < this->num_; ++n) { + this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data + top[i]->offset(n)); + if (this->bias_term_) { + const Dtype* bias = this->blobs_[1]->cpu_data(); + this->forward_cpu_bias(top_data + top[i]->offset(n), bias); } } } @@ -177,82 +37,37 @@ void ConvolutionLayer::Forward_cpu(const vector*>& bottom, template void ConvolutionLayer::Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { - const Dtype* weight = NULL; - Dtype* weight_diff = NULL; + const Dtype* weight = this->blobs_[0]->cpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); if (this->param_propagate_down_[0]) { - weight = this->blobs_[0]->cpu_data(); - weight_diff = this->blobs_[0]->mutable_cpu_diff(); caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff); } - Dtype* bias_diff = NULL; - if (bias_term_ && this->param_propagate_down_[1]) { - bias_diff = this->blobs_[1]->mutable_cpu_diff(); - caffe_set(this->blobs_[1]->count(), Dtype(0), bias_diff); + if (this->bias_term_ && this->param_propagate_down_[1]) { + caffe_set(this->blobs_[1]->count(), Dtype(0), + this->blobs_[1]->mutable_cpu_diff()); } - const int weight_offset = M_ * K_; - const int col_offset = K_ * N_; - const int top_offset = M_ * N_; for (int i = 0; i < top.size(); ++i) { - const Dtype* top_diff = NULL; + const Dtype* top_diff = top[i]->cpu_diff(); + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_cpu_diff(); // Bias gradient, if necessary. - if (bias_term_ && this->param_propagate_down_[1]) { - top_diff = top[i]->cpu_diff(); - for (int n = 0; n < num_; ++n) { - caffe_cpu_gemv(CblasNoTrans, num_output_, N_, - 1., top_diff + top[0]->offset(n), - bias_multiplier_.cpu_data(), 1., - bias_diff); + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); } } if (this->param_propagate_down_[0] || propagate_down[i]) { - if (!top_diff) { - top_diff = top[i]->cpu_diff(); - } - Dtype* col_buff = NULL; - if (!is_1x1_) { - col_buff = col_buffer_.mutable_cpu_data(); - } - const Dtype* bottom_data = bottom[i]->cpu_data(); - Dtype* bottom_diff = bottom[i]->mutable_cpu_diff(); - for (int n = 0; n < num_; ++n) { - // Since we saved memory in the forward pass by not storing all col - // data, we will need to recompute them. - if (!is_1x1_) { - im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, col_buff); - } else { - col_buff = bottom[i]->mutable_cpu_data() + bottom[i]->offset(n); - } + for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - for (int g = 0; g < group_; ++g) { - caffe_cpu_gemm(CblasNoTrans, CblasTrans, M_, K_, N_, - (Dtype)1., top_diff + top[i]->offset(n) + top_offset * g, - col_buff + col_offset * g, (Dtype)1., - weight_diff + weight_offset * g); - } + this->weight_cpu_gemm(bottom_data + bottom[i]->offset(n), + top_diff + top[i]->offset(n), weight_diff); } // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { - if (weight == NULL) { - weight = this->blobs_[0]->cpu_data(); - } - if (is_1x1_) { - col_buff = bottom[i]->mutable_cpu_diff() + bottom[i]->offset(n); - } - for (int g = 0; g < group_; ++g) { - caffe_cpu_gemm(CblasTrans, CblasNoTrans, K_, N_, M_, - (Dtype)1., weight + weight_offset * g, - top_diff + top[i]->offset(n) + top_offset * g, - (Dtype)0., col_buff + col_offset * g); - } - // col2im back to the data - if (!is_1x1_) { - col2im_cpu(col_buff, channels_, height_, width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, bottom_diff + bottom[i]->offset(n)); - } + this->backward_cpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n)); } } } diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu index af14facb523..3902fdf3930 100644 --- a/src/caffe/layers/conv_layer.cu +++ b/src/caffe/layers/conv_layer.cu @@ -8,135 +8,64 @@ namespace caffe { -/// @brief refer to CPU forward -- the BLAS implementation is the same. template void ConvolutionLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { + const Dtype* weight = this->blobs_[0]->gpu_data(); for (int i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->gpu_data(); Dtype* top_data = top[i]->mutable_gpu_data(); - Dtype* col_buff = NULL; - if (!is_1x1_) { - col_buff = col_buffer_.mutable_gpu_data(); - } - const Dtype* weight = this->blobs_[0]->gpu_data(); - int weight_offset = M_ * K_; - int col_offset = K_ * N_; - int top_offset = M_ * N_; - for (int n = 0; n < num_; ++n) { - // im2col transformation: unroll input regions for filtering - // into column matrix for multplication. - if (!is_1x1_) { - im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, - col_buff); - } else { - col_buff = bottom[i]->mutable_gpu_data() + bottom[i]->offset(n); - } - // Take inner products for groups. - for (int g = 0; g < group_; ++g) { - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, K_, - (Dtype)1., weight + weight_offset * g, col_buff + col_offset * g, - (Dtype)0., top_data + top[i]->offset(n) + top_offset * g); - } - // Add bias. - if (bias_term_) { - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, - N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(), - bias_multiplier_.gpu_data(), - (Dtype)1., top_data + top[i]->offset(n)); + for (int n = 0; n < this->num_; ++n) { + this->forward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data + top[i]->offset(n)); + if (this->bias_term_) { + const Dtype* bias = this->blobs_[1]->gpu_data(); + this->forward_gpu_bias(top_data + top[i]->offset(n), bias); } } } } -/// @brief refer to CPU backward -- the BLAS implementation is the same. template void ConvolutionLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { - const Dtype* weight = NULL; - Dtype* weight_diff = NULL; + const Dtype* weight = this->blobs_[0]->gpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); if (this->param_propagate_down_[0]) { - weight = this->blobs_[0]->gpu_data(); - weight_diff = this->blobs_[0]->mutable_gpu_diff(); caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff); } - Dtype* bias_diff = NULL; - if (bias_term_ && this->param_propagate_down_[1]) { - bias_diff = this->blobs_[1]->mutable_gpu_diff(); - caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff); + if (this->bias_term_ && this->param_propagate_down_[1]) { + caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), + this->blobs_[1]->mutable_gpu_diff()); } - const int weight_offset = M_ * K_; - const int col_offset = K_ * N_; - const int top_offset = M_ * N_; for (int i = 0; i < top.size(); ++i) { - const Dtype* top_diff = NULL; + const Dtype* top_diff = top[i]->gpu_diff(); // Bias gradient, if necessary. - if (bias_term_ && this->param_propagate_down_[1]) { - top_diff = top[i]->gpu_diff(); - for (int n = 0; n < num_; ++n) { - caffe_gpu_gemv(CblasNoTrans, num_output_, N_, - 1., top_diff + top[0]->offset(n), - bias_multiplier_.gpu_data(), 1., - bias_diff); + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); } } if (this->param_propagate_down_[0] || propagate_down[i]) { - if (!top_diff) { - top_diff = top[i]->gpu_diff(); - } - Dtype* col_buff = NULL; - if (!is_1x1_) { - col_buff = col_buffer_.mutable_gpu_data(); - } const Dtype* bottom_data = bottom[i]->gpu_data(); Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); - for (int n = 0; n < num_; ++n) { - // Since we saved memory in the forward pass by not storing all col - // data, we will need to recompute them. - if (!is_1x1_) { - im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, col_buff); - } else { - col_buff = bottom[i]->mutable_gpu_data() + bottom[i]->offset(n); - } + for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - for (int g = 0; g < group_; ++g) { - caffe_gpu_gemm(CblasNoTrans, CblasTrans, M_, K_, N_, - (Dtype)1., top_diff + top[i]->offset(n) + top_offset * g, - col_buff + col_offset * g, (Dtype)1., - weight_diff + weight_offset * g); - } + this->weight_gpu_gemm(bottom_data + bottom[i]->offset(n), + top_diff + top[i]->offset(n), weight_diff); } - // gradient w.r.t. bottom data, if necessary + // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { - if (weight == NULL) { - weight = this->blobs_[0]->gpu_data(); - } - if (is_1x1_) { - col_buff = bottom[i]->mutable_gpu_diff() + bottom[i]->offset(n); - } - for (int g = 0; g < group_; ++g) { - caffe_gpu_gemm(CblasTrans, CblasNoTrans, K_, N_, M_, - (Dtype)1., weight + weight_offset * g, - top_diff + top[i]->offset(n) + top_offset * g, - (Dtype)0., col_buff + col_offset * g); - } - // col2im back to the data - if (!is_1x1_) { - col2im_gpu(col_buff, channels_, height_, width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, - bottom_diff + bottom[i]->offset(n)); - } + this->backward_gpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n)); } } } } } - INSTANTIATE_LAYER_GPU_FUNCS(ConvolutionLayer); } // namespace caffe From b878285a3b4f8317f58e4655b8b14ef5be2588e2 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Sun, 21 Dec 2014 19:43:36 -0800 Subject: [PATCH 4/9] add DeconvolutionLayer, using BaseConvolutionLayer --- include/caffe/vision_layers.hpp | 20 ++++++++ src/caffe/layers/deconv_layer.cpp | 85 +++++++++++++++++++++++++++++++ src/caffe/layers/deconv_layer.cu | 71 ++++++++++++++++++++++++++ src/caffe/proto/caffe.proto | 3 +- 4 files changed, 178 insertions(+), 1 deletion(-) create mode 100644 src/caffe/layers/deconv_layer.cpp create mode 100644 src/caffe/layers/deconv_layer.cu diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 4d93e6c10fc..646378dea9f 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -171,8 +171,28 @@ class ConvolutionLayer : public BaseConvolutionLayer { const vector& propagate_down, const vector*>& bottom); virtual inline bool reverse_dimensions() { return false; } virtual void compute_output_shape(); +}; +template +class DeconvolutionLayer : public BaseConvolutionLayer { + public: + explicit DeconvolutionLayer(const LayerParameter& param) + : BaseConvolutionLayer(param) {} + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_DECONVOLUTION; + } + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual inline bool reverse_dimensions() { return true; } + virtual void compute_output_shape(); }; #ifdef USE_CUDNN diff --git a/src/caffe/layers/deconv_layer.cpp b/src/caffe/layers/deconv_layer.cpp new file mode 100644 index 00000000000..59114f017bf --- /dev/null +++ b/src/caffe/layers/deconv_layer.cpp @@ -0,0 +1,85 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void DeconvolutionLayer::compute_output_shape() { + this->height_out_ = this->stride_h_ * (this->height_ - 1) + this->kernel_h_ + - 2 * this->pad_h_; + this->width_out_ = this->stride_w_ * (this->width_ - 1) + this->kernel_w_ + - 2 * this->pad_w_; +} + +template +void DeconvolutionLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* weight = this->blobs_[0]->cpu_data(); + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* top_data = top[i]->mutable_cpu_data(); + for (int n = 0; n < this->num_; ++n) { + this->backward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data + top[i]->offset(n)); + if (this->bias_term_) { + const Dtype* bias = this->blobs_[1]->cpu_data(); + this->forward_cpu_bias(top_data + top[i]->offset(n), bias); + } + } + } +} + +template +void DeconvolutionLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* weight = this->blobs_[0]->cpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); + if (this->param_propagate_down_[0]) { + caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff); + } + if (this->bias_term_ && this->param_propagate_down_[1]) { + caffe_set(this->blobs_[1]->count(), Dtype(0), + this->blobs_[1]->mutable_cpu_diff()); + } + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->cpu_diff(); + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_cpu_diff(); + // Bias gradient, if necessary. + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); + } + } + if (this->param_propagate_down_[0] || propagate_down[i]) { + for (int n = 0; n < this->num_; ++n) { + // Gradient w.r.t. weight. Note that we will accumulate diffs. + if (this->param_propagate_down_[0]) { + this->weight_cpu_gemm(top_diff + top[i]->offset(n), + bottom_data + bottom[i]->offset(n), weight_diff); + } + // Gradient w.r.t. bottom data, if necessary, reusing the column buffer + // we might have just computed above. + if (propagate_down[i]) { + this->forward_cpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n), + this->param_propagate_down_[0]); + } + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(DeconvolutionLayer); +#endif + +INSTANTIATE_CLASS(DeconvolutionLayer); +REGISTER_LAYER_CLASS(DECONVOLUTION, DeconvolutionLayer); +} // namespace caffe diff --git a/src/caffe/layers/deconv_layer.cu b/src/caffe/layers/deconv_layer.cu new file mode 100644 index 00000000000..9198dd64c72 --- /dev/null +++ b/src/caffe/layers/deconv_layer.cu @@ -0,0 +1,71 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void DeconvolutionLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* top_data = top[i]->mutable_gpu_data(); + for (int n = 0; n < this->num_; ++n) { + this->backward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data + top[i]->offset(n)); + if (this->bias_term_) { + const Dtype* bias = this->blobs_[1]->gpu_data(); + this->forward_gpu_bias(top_data + top[i]->offset(n), bias); + } + } + } +} + +template +void DeconvolutionLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); + if (this->param_propagate_down_[0]) { + caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff); + } + if (this->bias_term_ && this->param_propagate_down_[1]) { + caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), + this->blobs_[1]->mutable_gpu_diff()); + } + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->gpu_diff(); + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); + // Bias gradient, if necessary. + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); + } + } + if (this->param_propagate_down_[0] || propagate_down[i]) { + for (int n = 0; n < this->num_; ++n) { + // gradient w.r.t. weight. Note that we will accumulate diffs. + if (this->param_propagate_down_[0]) { + this->weight_gpu_gemm(top_diff + top[i]->offset(n), + bottom_data + bottom[i]->offset(n), weight_diff); + } + // gradient w.r.t. bottom data, if necessary. + if (propagate_down[i]) { + this->forward_gpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n)); + } + } + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(DeconvolutionLayer); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 8086ad66579..e83eefa1622 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -227,7 +227,7 @@ message LayerParameter { // line above the enum. Update the next available ID when you add a new // LayerType. // - // LayerType next available ID: 39 (last added: EXP) + // LayerType next available ID: 40 (last added: DECONVOLUTION) enum LayerType { // "NONE" layer type is 0th enum element so that we don't cause confusion // by defaulting to an existent LayerType (instead, should usually error if @@ -241,6 +241,7 @@ message LayerParameter { CONTRASTIVE_LOSS = 37; CONVOLUTION = 4; DATA = 5; + DECONVOLUTION = 39; DROPOUT = 6; DUMMY_DATA = 32; EUCLIDEAN_LOSS = 7; From ab2393bbcd5c0cc6a0bcaf0f56ca47a65ebd93c9 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Wed, 24 Dec 2014 14:48:40 -0800 Subject: [PATCH 5/9] add util/coords.hpp for coordinate mapping functions This will be useful for keeping track of the coordinate transformations induced by, e.g., convolution and pooling layers. --- include/caffe/util/coords.hpp | 50 +++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 include/caffe/util/coords.hpp diff --git a/include/caffe/util/coords.hpp b/include/caffe/util/coords.hpp new file mode 100644 index 00000000000..8fbadc36a29 --- /dev/null +++ b/include/caffe/util/coords.hpp @@ -0,0 +1,50 @@ +#ifndef CAFFE_UTIL_COORDS_H_ +#define CAFFE_UTIL_COORDS_H_ + +#include +#include +#include + +namespace caffe { + +template +class DiagonalAffineMap { + public: + explicit DiagonalAffineMap(const vector > coefs) + : coefs_(coefs) { } + static DiagonalAffineMap identity(const int nd) { + return DiagonalAffineMap(vector >(nd, make_pair(1, 0))); + } + + inline DiagonalAffineMap compose(const DiagonalAffineMap& other) const { + CHECK_EQ(coefs_.size(), other.coefs_.size()) + << "Attempt to compose DiagonalAffineMaps of different dimensions"; + DiagonalAffineMap out; + transform(coefs_.begin(), coefs_.end(), other.coefs_.begin(), + std::back_inserter(out.coefs_), &compose_coefs); + return out; + } + inline DiagonalAffineMap inv() const { + DiagonalAffineMap out; + transform(coefs_.begin(), coefs_.end(), std::back_inserter(out.coefs_), + &inv_coefs); + return out; + } + inline vector > coefs() { return coefs_; } + + private: + DiagonalAffineMap() { } + static inline pair compose_coefs(pair left, + pair right) { + return make_pair(left.first * right.first, + left.first * right.second + left.second); + } + static inline pair inv_coefs(pair coefs) { + return make_pair(1 / coefs.first, - coefs.second / coefs.first); + } + vector > coefs_; +}; + +} // namespace caffe + +#endif // CAFFE_UTIL_COORDS_H_ From 6a243b1e2f98ddf64392d2f267b0586363a74324 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Thu, 25 Dec 2014 14:31:07 -0800 Subject: [PATCH 6/9] add FilterMap for the coord mapping used by (de)conv and pooling layers --- include/caffe/util/coords.hpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/include/caffe/util/coords.hpp b/include/caffe/util/coords.hpp index 8fbadc36a29..5032fc60abd 100644 --- a/include/caffe/util/coords.hpp +++ b/include/caffe/util/coords.hpp @@ -45,6 +45,17 @@ class DiagonalAffineMap { vector > coefs_; }; +template +DiagonalAffineMap FilterMap(const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w) { + vector > coefs; + coefs.push_back(make_pair(stride_h, + static_cast(kernel_h - 1) / 2 - pad_h)); + coefs.push_back(make_pair(stride_w, + static_cast(kernel_w - 1) / 2 - pad_w)); + return DiagonalAffineMap(coefs); +} + } // namespace caffe #endif // CAFFE_UTIL_COORDS_H_ From 98fc02f0bf5c9ed035ae6a117baa817a3e22dc91 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Thu, 25 Dec 2014 14:34:54 -0800 Subject: [PATCH 7/9] implement coord_map for all applicable layers --- include/caffe/common_layers.hpp | 18 ++++++++++++++++++ include/caffe/layer.hpp | 8 ++++++++ include/caffe/neuron_layers.hpp | 3 +++ include/caffe/vision_layers.hpp | 16 +++++++++++++++- 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 9718b825b14..db2310e30db 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -91,6 +91,9 @@ class ConcatLayer : public Layer { } virtual inline int MinBottomBlobs() const { return 2; } virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } protected: /** @@ -171,6 +174,9 @@ class EltwiseLayer : public Layer { } virtual inline int MinBottomBlobs() const { return 2; } virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } protected: virtual void Forward_cpu(const vector*>& bottom, @@ -301,6 +307,9 @@ class MVNLayer : public Layer { } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } protected: virtual void Forward_cpu(const vector*>& bottom, @@ -367,6 +376,9 @@ class SoftmaxLayer : public Layer { } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } protected: virtual void Forward_cpu(const vector*>& bottom, @@ -431,6 +443,9 @@ class SplitLayer : public Layer { } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int MinTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } protected: virtual void Forward_cpu(const vector*>& bottom, @@ -466,6 +481,9 @@ class SliceLayer : public Layer { } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int MinTopBlobs() const { return 2; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } protected: virtual void Forward_cpu(const vector*>& bottom, diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index 8a8330bca57..b14b632e694 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -3,12 +3,14 @@ #include #include +#include #include #include "caffe/blob.hpp" #include "caffe/common.hpp" #include "caffe/layer_factory.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/coords.hpp" #include "caffe/util/device_alternate.hpp" namespace caffe { @@ -293,6 +295,12 @@ class Layer { param_propagate_down_[param_id] = value; } + virtual DiagonalAffineMap coord_map() { + NOT_IMPLEMENTED; + // suppress warnings + return DiagonalAffineMap(vector >()); + } + protected: /** The protobuf that stores the layer parameters */ diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp index 5daeeefe7ae..312a2a0cf38 100644 --- a/include/caffe/neuron_layers.hpp +++ b/include/caffe/neuron_layers.hpp @@ -34,6 +34,9 @@ class NeuronLayer : public Layer { } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } }; /** diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 646378dea9f..b0f5c06ca74 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -159,6 +159,10 @@ class ConvolutionLayer : public BaseConvolutionLayer { virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_CONVOLUTION; } + virtual inline DiagonalAffineMap coord_map() { + return FilterMap(this->kernel_h_, this->kernel_w_, this->stride_h_, + this->stride_w_, this->pad_h_, this->pad_w_).inv(); + } protected: virtual void Forward_cpu(const vector*>& bottom, @@ -181,7 +185,10 @@ class DeconvolutionLayer : public BaseConvolutionLayer { virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_DECONVOLUTION; } - + virtual inline DiagonalAffineMap coord_map() { + return FilterMap(this->kernel_h_, this->kernel_w_, this->stride_h_, + this->stride_w_, this->pad_h_, this->pad_w_); + } protected: virtual void Forward_cpu(const vector*>& bottom, const vector*>& top); @@ -301,6 +308,9 @@ class LRNLayer : public Layer { } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } protected: virtual void Forward_cpu(const vector*>& bottom, @@ -385,6 +395,10 @@ class PoolingLayer : public Layer { return (this->layer_param_.pooling_param().pool() == PoolingParameter_PoolMethod_MAX) ? 2 : 1; } + virtual inline DiagonalAffineMap coord_map() { + return FilterMap(kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_h_, pad_w_).inv(); + } protected: virtual void Forward_cpu(const vector*>& bottom, From 1081373e6953cdfe5fc0b29d79a0bf8d063e0561 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Sat, 27 Dec 2014 01:39:15 -0800 Subject: [PATCH 8/9] layers get a pointer back to their owning Net This allows layers to do things that depend, e.g., on net topology. --- include/caffe/layer.hpp | 9 +++++++++ src/caffe/net.cpp | 1 + 2 files changed, 10 insertions(+) diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index b14b632e694..6c768b9bbe2 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -15,6 +15,8 @@ namespace caffe { +template class Net; + /** * @brief An interface for the units of computation which can be composed into a * Net. @@ -301,6 +303,10 @@ class Layer { return DiagonalAffineMap(vector >()); } + /** + * @brief Used by Net to give layers a pointer to their owning net. + */ + void set_net(Net* net) { net_ = net; } protected: /** The protobuf that stores the layer parameters */ @@ -314,6 +320,9 @@ class Layer { * the objective function. */ vector loss_; + /** The net to which this layer belongs. */ + Net* net_; + /** @brief Using the CPU device, compute the layer output. */ virtual void Forward_cpu(const vector*>& bottom, const vector*>& top) = 0; diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index e4492cfd6e3..67f8a22b186 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -64,6 +64,7 @@ void Net::Init(const NetParameter& in_param) { const LayerParameter& layer_param = param.layers(layer_id); layers_.push_back(shared_ptr >( LayerRegistry::CreateLayer(layer_param))); + layers_[layer_id]->set_net(this); layer_names_.push_back(layer_param.name()); LOG(INFO) << "Creating Layer " << layer_param.name(); bool need_backward = false; From c1b7bab40c185f264168e39768dcba22994ceaa1 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Sat, 27 Dec 2014 01:44:36 -0800 Subject: [PATCH 9/9] add CropLayer for cropping one blob to another using induced coordinates --- include/caffe/vision_layers.hpp | 35 +++++++++ src/caffe/layers/crop_layer.cpp | 126 ++++++++++++++++++++++++++++++++ src/caffe/layers/crop_layer.cu | 60 +++++++++++++++ src/caffe/proto/caffe.proto | 3 +- 4 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 src/caffe/layers/crop_layer.cpp create mode 100644 src/caffe/layers/crop_layer.cu diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index b0f5c06ca74..29ec37acb8b 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -453,6 +453,41 @@ class CuDNNPoolingLayer : public PoolingLayer { }; #endif +template +class CropLayer : public Layer { + public: + explicit CropLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_CROP; + } + virtual inline int ExactNumBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + vector > coefs; + coefs.push_back(make_pair(1, - crop_h_)); + coefs.push_back(make_pair(1, - crop_w_)); + return DiagonalAffineMap(coefs); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int crop_h_, crop_w_; +}; + } // namespace caffe #endif // CAFFE_VISION_LAYERS_HPP_ diff --git a/src/caffe/layers/crop_layer.cpp b/src/caffe/layers/crop_layer.cpp new file mode 100644 index 00000000000..f07012f2bb5 --- /dev/null +++ b/src/caffe/layers/crop_layer.cpp @@ -0,0 +1,126 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/net.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CropLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + // Construct a map from top blobs to layer inds, skipping over in-place + // connections. + map*, int> down_map; + for (int layer_ind = 0; layer_ind < this->net_->top_vecs().size(); + ++layer_ind) { + vector*> tops = this->net_->top_vecs()[layer_ind]; + for (int top_ind = 0; top_ind < tops.size(); ++top_ind) { + if (down_map.find(tops[top_ind]) == down_map.end()) { + down_map[tops[top_ind]] = layer_ind; + } + } + } + // Walk back from the first bottom, keeping track of all the blobs we pass. + set*> path_blobs; + Blob* blob = bottom[0]; + int layer_ind; + // TODO this logic can be simplified if all blobs are tops + path_blobs.insert(blob); + while (down_map.find(blob) != down_map.end()) { + layer_ind = down_map[blob]; + if (this->net_->bottom_vecs()[layer_ind].size() == 0) { + break; + } + blob = this->net_->bottom_vecs()[layer_ind][0]; + path_blobs.insert(blob); + } + // Now walk back from the second bottom, until we find a blob of intersection. + Blob* inter_blob = bottom[1]; + while (path_blobs.find(inter_blob) == path_blobs.end()) { + CHECK(down_map.find(inter_blob) != down_map.end()) + << "Cannot align apparently disconnected blobs."; + layer_ind = down_map[inter_blob]; + CHECK_GT(this->net_->bottom_vecs()[layer_ind].size(), 0) + << "Cannot align apparently disconnected blobs."; + inter_blob = this->net_->bottom_vecs()[layer_ind][0]; + } + // Compute the coord map from the blob of intersection to each bottom. + vector > coord_maps(2, + DiagonalAffineMap::identity(2)); + for (int i = 0; i < 2; ++i) { + for (Blob* blob = bottom[i]; blob != inter_blob; + blob = this->net_->bottom_vecs()[down_map[blob]][0]) { + shared_ptr > layer = this->net_->layers()[down_map[blob]]; + coord_maps[i] = coord_maps[i].compose(layer->coord_map()); + } + } + // Compute the mapping from first bottom coordinates to second. + DiagonalAffineMap crop_map = + coord_maps[1].compose(coord_maps[0].inv()); + for (int i = 0; i < 2; ++i) { + // Check for scale mismatch (unfortunately, CHECK_DOUBLE_EQ does not + // support a message like the other CHECKs). + CHECK_DOUBLE_EQ(crop_map.coefs()[i].first, 1); + CHECK_LE(crop_map.coefs()[i].second, 0) << "Negative crop width."; + // Check that the crop width is an integer. + CHECK_DOUBLE_EQ(crop_map.coefs()[i].second, + round(crop_map.coefs()[i].second)); + } + crop_h_ = - round(crop_map.coefs()[0].second); + crop_w_ = - round(crop_map.coefs()[1].second); +} + +template +void CropLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[1]->height(), + bottom[1]->width()); +} + +template +void CropLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + for (int n = 0; n < top[0]->num(); ++n) { + for (int c = 0; c < top[0]->channels(); ++c) { + for (int h = 0; h < top[0]->height(); ++h) { + caffe_copy(top[0]->width(), + bottom_data + bottom[0]->offset(n, c, crop_h_ + h, crop_w_), + top_data + top[0]->offset(n, c, h)); + } + } + } +} + +template +void CropLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + if (propagate_down[0]) { + caffe_set(bottom[0]->count(), static_cast(0), bottom_diff); + for (int n = 0; n < top[0]->num(); ++n) { + for (int c = 0; c < top[0]->channels(); ++c) { + for (int h = 0; h < top[0]->height(); ++h) { + caffe_copy(top[0]->width(), + top_diff + top[0]->offset(n, c, h), + bottom_diff + bottom[0]->offset(n, c, crop_h_ + h, crop_w_)); + } + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(CropLayer); +#endif + +INSTANTIATE_CLASS(CropLayer); +REGISTER_LAYER_CLASS(CROP, CropLayer); + +} // namespace caffe diff --git a/src/caffe/layers/crop_layer.cu b/src/caffe/layers/crop_layer.cu new file mode 100644 index 00000000000..2dd3ff95d91 --- /dev/null +++ b/src/caffe/layers/crop_layer.cu @@ -0,0 +1,60 @@ +#include + +#include "caffe/vision_layers.hpp" + +namespace caffe { + +// Copy (one line per thread) from one array to another, with arbitrary +// strides in the last two dimensions. +template +__global__ void copy_kernel(const int n, const int height, const int width, + const int src_outer_stride, const int src_inner_stride, + const int dest_outer_stride, const int dest_inner_stride, + const Dtype* src, Dtype* dest) { + CUDA_KERNEL_LOOP(index, n) { + int src_start = index / height * src_outer_stride + + index % height * src_inner_stride; + int dest_start = index / height * dest_outer_stride + + index % height * dest_inner_stride; + for (int i = 0; i < width; ++i) { + dest[dest_start + i] = src[src_start + i]; + } + } +} + +template +void CropLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const int lines = top[0]->count() / top[0]->width(); + + // NOLINT_NEXT_LINE(whitespace/operators) + copy_kernel<<>>( + lines, top[0]->height(), top[0]->width(), + bottom[0]->height() * bottom[0]->width(), bottom[0]->width(), + top[0]->height() * top[0]->width(), top[0]->width(), + bottom_data + bottom[0]->offset(0, 0, crop_h_, crop_w_), top_data); +} + +template +void CropLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + const int lines = top[0]->count() / top[0]->width(); + + if (propagate_down[0]) { + caffe_gpu_set(bottom[0]->count(), static_cast(0), bottom_diff); + // NOLINT_NEXT_LINE(whitespace/operators) + copy_kernel<<>>( + lines, top[0]->height(), top[0]->width(), + top[0]->height() * top[0]->width(), top[0]->width(), + bottom[0]->height() * bottom[0]->width(), bottom[0]->width(), + top_diff, bottom_diff + bottom[0]->offset(0, 0, crop_h_, crop_w_)); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(CropLayer); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index e83eefa1622..022a1614aa2 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -227,7 +227,7 @@ message LayerParameter { // line above the enum. Update the next available ID when you add a new // LayerType. // - // LayerType next available ID: 40 (last added: DECONVOLUTION) + // LayerType next available ID: 41 (last added: CROP) enum LayerType { // "NONE" layer type is 0th enum element so that we don't cause confusion // by defaulting to an existent LayerType (instead, should usually error if @@ -240,6 +240,7 @@ message LayerParameter { CONCAT = 3; CONTRASTIVE_LOSS = 37; CONVOLUTION = 4; + CROP = 40; DATA = 5; DECONVOLUTION = 39; DROPOUT = 6;