diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index fa516144d25..a3c8bc0a31f 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -246,6 +246,7 @@ class LRNLayer : public Layer { int pre_pad_; Dtype alpha_; Dtype beta_; + Dtype k_; int num_; int channels_; int height_; diff --git a/matlab/caffe/matcaffe_demo_vgg.m b/matlab/caffe/matcaffe_demo_vgg.m new file mode 100644 index 00000000000..4e5a98eb5f4 --- /dev/null +++ b/matlab/caffe/matcaffe_demo_vgg.m @@ -0,0 +1,96 @@ +function scores = matcaffe_demo_vgg(im, use_gpu, model_def_file, model_file, mean_file) +% scores = matcaffe_demo_vgg(im, use_gpu, model_def_file, model_file, mean_file) +% +% Demo of the matlab wrapper using the networks described in the BMVC-2014 paper "Return of the Devil in the Details: Delving Deep into Convolutional Nets" +% +% INPUT +% im - color image as uint8 HxWx3 +% use_gpu - 1 to use the GPU, 0 to use the CPU +% model_def_file - network configuration (.prototxt file) +% model_file - network weights (.caffemodel file) +% mean_file - mean BGR image as uint8 HxWx3 (.mat file, read from its variable image_mean) +% +% OUTPUT +% scores - ILSVRC class scores for each of the 10 crops, not averaged (presumably 1000x10 - TODO confirm) +% +% EXAMPLE USAGE +% model_def_file = 'zoo/VGG_CNN_F_deploy.prototxt'; +% model_file = 'zoo/VGG_CNN_F.caffemodel'; +% mean_file = 'zoo/VGG_mean.mat'; +% use_gpu = true; +% im = imread('../../examples/images/cat.jpg'); +% scores = matcaffe_demo_vgg(im, use_gpu, model_def_file, model_file, mean_file); +% +% NOTES +% the image crops are prepared as described in the paper (the aspect ratio is preserved) +% +% PREREQUISITES +% You may need to do the following before you start matlab: +% $ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda/lib64 +% $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 +% Or the equivalent based on where things are installed on your system + +% init caffe network (spews logging info) +matcaffe_init(use_gpu, model_def_file, model_file); + +% 
prepare oversampled input +% input_data is Width x Height x Channel x Num (prepare_image transposes each crop) +tic; +input_data = {prepare_image(im, mean_file)}; +toc; + +% do forward pass to get scores +% scores are now Width x Height x Channels x Num +tic; +scores = caffe('forward', input_data); +toc; + +scores = scores{1}; +% size(scores) +scores = squeeze(scores); +% scores = mean(scores,2); + +% [~,maxlabel] = max(scores); + +% ------------------------------------------------------------------------ +function images = prepare_image(im, mean_file) +% ------------------------------------------------------------------------ +IMAGE_DIM = 256; +CROPPED_DIM = 224; + +d = load(mean_file); +IMAGE_MEAN = d.image_mean; + +% convert to single, then resize so the smaller side equals IMAGE_DIM (aspect ratio kept) +im = single(im); + +if size(im, 1) < size(im, 2) + im = imresize(im, [IMAGE_DIM NaN]); +else + im = imresize(im, [NaN IMAGE_DIM]); +end + +% RGB -> BGR +im = im(:, :, [3 2 1]); + +% oversample: crops 1-4 are the corners, 5 the center, 6-10 their x-axis flips +images = zeros(CROPPED_DIM, CROPPED_DIM, 3, 10, 'single'); + +indices_y = [0 size(im,1)-CROPPED_DIM] + 1; +indices_x = [0 size(im,2)-CROPPED_DIM] + 1; +center_y = floor(indices_y(2) / 2)+1; +center_x = floor(indices_x(2) / 2)+1; + +curr = 1; +for i = indices_y + for j = indices_x + images(:, :, :, curr) = ... + permute(im(i:i+CROPPED_DIM-1, j:j+CROPPED_DIM-1, :)-IMAGE_MEAN, [2 1 3]); + images(:, :, :, curr+5) = images(end:-1:1, :, :, curr); + curr = curr + 1; + end +end +images(:,:,:,5) = ... + permute(im(center_y:center_y+CROPPED_DIM-1,center_x:center_x+CROPPED_DIM-1,:)-IMAGE_MEAN, ... 
+ [2 1 3]); +images(:,:,:,10) = images(end:-1:1, :, :, curr); diff --git a/matlab/caffe/matcaffe_demo_vgg_mean_pix.m b/matlab/caffe/matcaffe_demo_vgg_mean_pix.m new file mode 100644 index 00000000000..5f7898a7029 --- /dev/null +++ b/matlab/caffe/matcaffe_demo_vgg_mean_pix.m @@ -0,0 +1,102 @@ +function scores = matcaffe_demo_vgg_mean_pix(im, use_gpu, model_def_file, model_file) +% scores = matcaffe_demo_vgg_mean_pix(im, use_gpu, model_def_file, model_file) +% +% Demo of the matlab wrapper based on the networks used for the "VGG" entry +% in the ILSVRC-2014 competition and described in the tech. report +% "Very Deep Convolutional Networks for Large-Scale Image Recognition" +% http://arxiv.org/abs/1409.1556/ +% +% INPUT +% im - color image as uint8 HxWx3 +% use_gpu - 1 to use the GPU, 0 to use the CPU +% model_def_file - network configuration (.prototxt file) +% model_file - network weights (.caffemodel file) +% +% OUTPUT +% scores - ILSVRC class scores for each of the 10 crops, not averaged (presumably 1000x10 - TODO confirm) +% +% EXAMPLE USAGE +% model_def_file = 'zoo/deploy.prototxt'; +% model_file = 'zoo/model.caffemodel'; +% use_gpu = true; +% im = imread('../../examples/images/cat.jpg'); +% scores = matcaffe_demo_vgg_mean_pix(im, use_gpu, model_def_file, model_file); +% +% NOTES +% mean pixel subtraction is used instead of the mean image subtraction +% +% PREREQUISITES +% You may need to do the following before you start matlab: +% $ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda/lib64 +% $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 +% Or the equivalent based on where things are installed on your system + +% init caffe network (spews logging info) +matcaffe_init(use_gpu, model_def_file, model_file); + +% mean BGR pixel +mean_pix = [103.939, 116.779, 123.68]; + +% prepare oversampled input +% input_data is Width x Height x Channel x Num (prepare_image transposes each crop) +tic; +input_data = {prepare_image(im, mean_pix)}; +toc; + +% do forward pass to get scores +% scores are now Width x Height x Channels x Num +tic; +scores = 
caffe('forward', input_data); +toc; + +scores = scores{1}; +% size(scores) +scores = squeeze(scores); +% scores = mean(scores,2); + +% [~,maxlabel] = max(scores); + +% ------------------------------------------------------------------------ +function images = prepare_image(im, mean_pix) +% ------------------------------------------------------------------------ +IMAGE_DIM = 256; +CROPPED_DIM = 224; + +% convert to single, then resize so the smaller side equals IMAGE_DIM (aspect ratio kept) +im = single(im); + +if size(im, 1) < size(im, 2) + im = imresize(im, [IMAGE_DIM NaN]); +else + im = imresize(im, [NaN IMAGE_DIM]); +end + +% RGB -> BGR +im = im(:, :, [3 2 1]); + +% oversample: crops 1-4 are the corners, 5 the center, 6-10 their x-axis flips +images = zeros(CROPPED_DIM, CROPPED_DIM, 3, 10, 'single'); + +indices_y = [0 size(im,1)-CROPPED_DIM] + 1; +indices_x = [0 size(im,2)-CROPPED_DIM] + 1; +center_y = floor(indices_y(2) / 2)+1; +center_x = floor(indices_x(2) / 2)+1; + +curr = 1; +for i = indices_y + for j = indices_x + images(:, :, :, curr) = ... + permute(im(i:i+CROPPED_DIM-1, j:j+CROPPED_DIM-1, :), [2 1 3]); + images(:, :, :, curr+5) = images(end:-1:1, :, :, curr); + curr = curr + 1; + end +end +images(:,:,:,5) = ... + permute(im(center_y:center_y+CROPPED_DIM-1,center_x:center_x+CROPPED_DIM-1,:), ... 
+ [2 1 3]); +images(:,:,:,10) = images(end:-1:1, :, :, curr); + +% subtract the per-channel mean BGR pixel from every crop +for c = 1:3 + images(:, :, c, :) = images(:, :, c, :) - mean_pix(c); +end diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp index fb74b03dc88..a09a47940d2 100644 --- a/src/caffe/layers/lrn_layer.cpp +++ b/src/caffe/layers/lrn_layer.cpp @@ -14,6 +14,7 @@ void LRNLayer::LayerSetUp(const vector*>& bottom, pre_pad_ = (size_ - 1) / 2; alpha_ = this->layer_param_.lrn_param().alpha(); beta_ = this->layer_param_.lrn_param().beta(); + k_ = this->layer_param_.lrn_param().k(); if (this->layer_param_.lrn_param().norm_region() == LRNParameter_NormRegion_WITHIN_CHANNEL) { // Set up split_layer_ to use inputs in the numerator and denominator. @@ -110,7 +111,7 @@ void LRNLayer::CrossChannelForward_cpu( Dtype* scale_data = scale_.mutable_cpu_data(); // start with the constant value for (int i = 0; i < scale_.count(); ++i) { - scale_data[i] = 1.; + scale_data[i] = k_; } Blob padded_square(1, channels_ + size_ - 1, height_, width_); Dtype* padded_square_data = padded_square.mutable_cpu_data(); diff --git a/src/caffe/layers/lrn_layer.cu b/src/caffe/layers/lrn_layer.cu index ee5e359ff0b..47b003bc4aa 100644 --- a/src/caffe/layers/lrn_layer.cu +++ b/src/caffe/layers/lrn_layer.cu @@ -10,7 +10,7 @@ template __global__ void LRNFillScale(const int nthreads, const Dtype* in, const int num, const int channels, const int height, const int width, const int size, const Dtype alpha_over_size, - Dtype* scale) { + const Dtype k, Dtype* scale) { CUDA_KERNEL_LOOP(index, nthreads) { // find out the local offset int w = index % width; @@ -33,20 +33,20 @@ __global__ void LRNFillScale(const int nthreads, const Dtype* in, // until we reach size, nothing needs to be subtracted while (head < size) { accum_scale += in[head * step] * in[head * step]; - scale[(head - post_pad) * step] = 1. 
+ accum_scale * alpha_over_size; + scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size; ++head; } // both add and subtract while (head < channels) { accum_scale += in[head * step] * in[head * step]; accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size; + scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size; ++head; } // subtract only while (head < channels + post_pad) { accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size; + scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size; ++head; } } @@ -90,7 +90,7 @@ void LRNLayer::CrossChannelForward_gpu( // NOLINT_NEXT_LINE(whitespace/operators) LRNFillScale<<>>( n_threads, bottom_data, num_, channels_, height_, width_, size_, - alpha_ / size_, scale_data); + alpha_ / size_, k_, scale_data); CUDA_POST_KERNEL_CHECK; n_threads = bottom[0]->count(); // NOLINT_NEXT_LINE(whitespace/operators) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 9395c38f3e9..01a516ee3b9 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -548,6 +548,7 @@ message LRNParameter { WITHIN_CHANNEL = 1; } optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; } // Message that stores parameters used by MemoryDataLayer @@ -715,6 +716,7 @@ message V0LayerParameter { optional uint32 local_size = 13 [default = 5]; // for local response norm optional float alpha = 14 [default = 1.]; // for local response norm optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; // for local response norm // For data layers, specify the data source optional string source = 16; diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index c69c58eb340..cbd6003c948 100644 --- a/src/caffe/util/upgrade_proto.cpp 
+++ b/src/caffe/util/upgrade_proto.cpp @@ -285,6 +285,14 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection, is_fully_compatible = false; } } + if (v0_layer_param.has_k()) {  // V0 k is only carried over for "lrn" layers (into lrn_param.k) + if (type == "lrn") { + layer_param->mutable_lrn_param()->set_k(v0_layer_param.k()); + } else { + LOG(ERROR) << "Unknown parameter k for layer type " << type; + is_fully_compatible = false; + } + } if (v0_layer_param.has_source()) { if (type == "data") { layer_param->mutable_data_param()->set_source(v0_layer_param.source());