Add ScalarLayer to multiply two Blobs, broadcasting the shape of the
second as needed
jeffdonahue committed Sep 3, 2015
1 parent 66823b5 commit 3abc491
Showing 5 changed files with 531 additions and 1 deletion.
53 changes: 53 additions & 0 deletions include/caffe/common_layers.hpp
@@ -511,6 +511,59 @@ class SilenceLayer : public Layer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
};

/**
* @brief Computes a product of two input Blobs, with the shape of the
* latter Blob "broadcast" to match the shape of the former.
* Equivalent to tiling the latter Blob, then computing the elementwise
* product.
*/
template <typename Dtype>
class ScalarLayer: public Layer<Dtype> {
public:
explicit ScalarLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "Scalar"; }
virtual inline int ExactNumBottomBlobs() const { return 2; }
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
/**
* In the below shape specifications, @f$ i @f$ denotes the value of the
* `axis` field given by `this->layer_param_.scalar_param().axis()`, after
* canonicalization (i.e., conversion from negative to positive index,
* if applicable).
*
* @param bottom input Blob vector (length 2)
* -# @f$ (d_0 \times ... \times
* d_i \times ... \times d_j \times ... \times d_n) @f$
* the first factor @f$ x @f$
* -# @f$ (d_i \times ... \times d_j) @f$
* the second factor @f$ y @f$
* @param top output Blob vector (length 1)
* -# @f$ (d_0 \times ... \times
* d_i \times ... \times d_j \times ... \times d_n) @f$
* the product @f$ z = x y @f$ computed after "broadcasting" y.
* Equivalent to tiling @f$ y @f$ to have the same shape as @f$ x @f$,
* then computing the elementwise product.
*/
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

Blob<Dtype> sum_multiplier_;
Blob<Dtype> sum_result_;
int axis_;
int outer_dim_, scalar_dim_, inner_dim_;
};

/**
* @brief Computes the softmax function.
*
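
For reference, viewing bottom[0] as (outer_dim_ x scalar_dim_ x inner_dim_) around the chosen axis (the counts that ScalarLayer::Reshape computes in the .cpp file below) and bottom[1] as a flat vector of scalar_dim_ elements, the broadcast product described in the class comment reduces to the naive loop below. This is a minimal CPU sketch, not part of the commit; the function name and the flattened std::vector interface are invented for the example.

#include <vector>

// Naive reference for ScalarLayer's forward broadcast (illustrative only).
void scalar_forward_reference(const std::vector<float>& x,  // bottom[0] data
                              const std::vector<float>& y,  // bottom[1] data
                              int outer_dim, int scalar_dim, int inner_dim,
                              std::vector<float>* z) {      // top[0] data
  z->resize(x.size());
  for (int n = 0; n < outer_dim; ++n) {
    for (int d = 0; d < scalar_dim; ++d) {
      for (int i = 0; i < inner_dim; ++i) {
        const int idx = (n * scalar_dim + d) * inner_dim + i;
        // Each element of y is tiled across the outer and inner dimensions.
        (*z)[idx] = x[idx] * y[d];
      }
    }
  }
}
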
115 changes: 115 additions & 0 deletions src/caffe/layers/scalar_layer.cpp
@@ -0,0 +1,115 @@
#include <algorithm>
#include <vector>

#include "caffe/common_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void ScalarLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// TODO: make ScalarLayer usable in-place.
// Currently, in-place computation is broken during Backward with
// propagate_down[0] && propagate_down[1], as bottom[0]'s diff is used for
// temporary storage of an intermediate result, overwriting top[0]'s diff
// if using in-place computation.
CHECK_NE(bottom[0], top[0]) << "ScalarLayer cannot be used in-place";
axis_ =
bottom[0]->CanonicalAxisIndex(this->layer_param_.scalar_param().axis());
CHECK_GE(bottom[0]->num_axes(), axis_ + bottom[1]->num_axes())
<< "bottom[1]'s shape extends past bottom[0]'s shape when applied "
<< "starting with bottom[0] axis = " << axis_;
for (int i = 0; i < bottom[1]->num_axes(); ++i) {
CHECK_EQ(bottom[0]->shape(axis_ + i), bottom[1]->shape(i))
<< "dimension mismatch between bottom[0]->shape(" << axis_ + i
<< ") and bottom[1]->shape(" << i << ")";
}
outer_dim_ = bottom[0]->count(0, axis_);
scalar_dim_ = bottom[1]->count();
inner_dim_ = bottom[0]->count(axis_ + bottom[1]->num_axes());
top[0]->ReshapeLike(*bottom[0]);
sum_result_.Reshape(vector<int>(1, outer_dim_ * scalar_dim_));
const int sum_mult_size = std::max(outer_dim_, inner_dim_);
sum_multiplier_.Reshape(vector<int>(1, sum_mult_size));
if (sum_multiplier_.cpu_data()[sum_mult_size - 1] != Dtype(1)) {
caffe_set(sum_mult_size, Dtype(1), sum_multiplier_.mutable_cpu_data());
}
}

template <typename Dtype>
void ScalarLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* scalar_data = bottom[1]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
for (int n = 0; n < outer_dim_; ++n) {
for (int d = 0; d < scalar_dim_; ++d) {
const Dtype factor = scalar_data[d];
caffe_cpu_scale(inner_dim_, factor, bottom_data, top_data);
bottom_data += inner_dim_;
top_data += inner_dim_;
}
}
}

template <typename Dtype>
void ScalarLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[1]) {
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* bottom_data = bottom[0]->cpu_data();
// Hack: store big eltwise product in bottom[0] diff, except in the special
// case where this layer itself does the eltwise product, in which case we
// can store it directly in the scalar diff, and we're done.
const bool is_eltwise = (inner_dim_ == 1 && outer_dim_ == 1);
Dtype* product = is_eltwise ?
bottom[1]->mutable_cpu_diff() : bottom[0]->mutable_cpu_diff();
caffe_mul(top[0]->count(), top_diff, bottom_data, product);
if (!is_eltwise) {
Dtype* sum_result = NULL;
if (inner_dim_ == 1) {
sum_result = product;
} else if (sum_result_.count() == 1) {
const Dtype* sum_mult = sum_multiplier_.cpu_data();
Dtype* scalar_diff = bottom[1]->mutable_cpu_diff();
*scalar_diff = caffe_cpu_dot(inner_dim_, product, sum_mult);
} else {
const Dtype* sum_mult = sum_multiplier_.cpu_data();
sum_result = (outer_dim_ == 1) ?
bottom[1]->mutable_cpu_diff() : sum_result_.mutable_cpu_data();
caffe_cpu_gemv(CblasNoTrans, sum_result_.count(), inner_dim_,
Dtype(1), product, sum_mult, Dtype(0), sum_result);
}
if (outer_dim_ != 1) {
const Dtype* sum_mult = sum_multiplier_.cpu_data();
Dtype* scalar_diff = bottom[1]->mutable_cpu_diff();
if (scalar_dim_ == 1) {
*scalar_diff = caffe_cpu_dot(outer_dim_, sum_mult, sum_result);
} else {
caffe_cpu_gemv(CblasTrans, outer_dim_, scalar_dim_,
Dtype(1), sum_result, sum_mult, Dtype(0), scalar_diff);
}
}
}
}
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* scalar_data = bottom[1]->cpu_data();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
for (int n = 0; n < outer_dim_; ++n) {
for (int d = 0; d < scalar_dim_; ++d) {
const Dtype factor = scalar_data[d];
caffe_cpu_scale(inner_dim_, factor, top_diff, bottom_diff);
bottom_diff += inner_dim_;
top_diff += inner_dim_;
}
}
}
}

INSTANTIATE_CLASS(ScalarLayer);
REGISTER_LAYER_CLASS(Scalar);

} // namespace caffe
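
The caffe_mul, caffe_cpu_gemv, and caffe_cpu_dot calls in Backward_cpu compute, in bulk, the reduction sketched below: the diff for each element of bottom[1] is the sum of top_diff times bottom[0] data over every position that element was broadcast to. The following naive reference is illustrative only and not part of the commit; the function name and the flattened std::vector interface are invented for the example.

#include <vector>

// Naive reference for the bottom[1] gradient (illustrative only).
void scalar_backward_reference(const std::vector<float>& top_diff,
                               const std::vector<float>& bottom_data,
                               int outer_dim, int scalar_dim, int inner_dim,
                               std::vector<float>* scalar_diff) {
  scalar_diff->assign(scalar_dim, 0.0f);
  for (int n = 0; n < outer_dim; ++n) {
    for (int d = 0; d < scalar_dim; ++d) {
      for (int i = 0; i < inner_dim; ++i) {
        const int idx = (n * scalar_dim + d) * inner_dim + i;
        // Accumulate over the inner dimension (the gemv with sum_multiplier_)
        // and then over the outer dimension (the second gemv or dot).
        (*scalar_diff)[d] += top_diff[idx] * bottom_data[idx];
      }
    }
  }
}
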
86 changes: 86 additions & 0 deletions src/caffe/layers/scalar_layer.cu
@@ -0,0 +1,86 @@
#include <cfloat>
#include <vector>

#include "caffe/common_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
__global__ void ScalarForward(const int n, const Dtype* in,
const Dtype* scalars, const int scalar_dim, const int inner_dim,
Dtype* out) {
CUDA_KERNEL_LOOP(index, n) {
const int scalar_index = (index / inner_dim) % scalar_dim;
out[index] = in[index] * scalars[scalar_index];
}
}

template <typename Dtype>
void ScalarLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
const int count = top[0]->count();
const Dtype* bottom_data = bottom[0]->gpu_data();
const Dtype* scalar_data = bottom[1]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
ScalarForward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, scalar_data, scalar_dim_, inner_dim_, top_data);
}

template <typename Dtype>
void ScalarLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[1]) {
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
// Hack: store big eltwise product in bottom[0] diff, except in the special
// case where this layer itself does the eltwise product, in which case we
// can store it directly in the scalar diff, and we're done.
const bool is_eltwise = (inner_dim_ == 1 && outer_dim_ == 1);
Dtype* product = is_eltwise ?
bottom[1]->mutable_gpu_diff() : bottom[0]->mutable_gpu_diff();
caffe_gpu_mul(top[0]->count(), top_diff, bottom_data, product);
if (!is_eltwise) {
Dtype* sum_result = NULL;
if (inner_dim_ == 1) {
sum_result = product;
} else if (sum_result_.count() == 1) {
const Dtype* sum_mult = sum_multiplier_.gpu_data();
// Note: caffe_gpu_dot writes its result to a host-side pointer, so the
// scalar diff is written via mutable_cpu_diff() here (and in the
// scalar_dim_ == 1 branch below) rather than mutable_gpu_diff().
Dtype* scalar_diff = bottom[1]->mutable_cpu_diff();
caffe_gpu_dot(inner_dim_, product, sum_mult, scalar_diff);
} else {
const Dtype* sum_mult = sum_multiplier_.gpu_data();
sum_result = (outer_dim_ == 1) ?
bottom[1]->mutable_gpu_diff() : sum_result_.mutable_gpu_data();
caffe_gpu_gemv(CblasNoTrans, sum_result_.count(), inner_dim_,
Dtype(1), product, sum_mult, Dtype(0), sum_result);
}
if (outer_dim_ != 1) {
const Dtype* sum_mult = sum_multiplier_.gpu_data();
if (scalar_dim_ == 1) {
Dtype* scalar_diff = bottom[1]->mutable_cpu_diff();
caffe_gpu_dot(outer_dim_, sum_mult, sum_result, scalar_diff);
} else {
Dtype* scalar_diff = bottom[1]->mutable_gpu_diff();
caffe_gpu_gemv(CblasTrans, outer_dim_, scalar_dim_,
Dtype(1), sum_result, sum_mult, Dtype(0), scalar_diff);
}
}
}
}
if (propagate_down[0]) {
const int count = top[0]->count();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* scalar_data = bottom[1]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
ScalarForward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, scalar_data, scalar_dim_, inner_dim_, bottom_diff);
}
}

INSTANTIATE_LAYER_GPU_FUNCS(ScalarLayer);

} // namespace caffe
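
The ScalarForward kernel recovers the scalar index from a flat element index with (index / inner_dim) % scalar_dim. Below is a small host-side check that this decomposition matches the nested-loop indexing used on the CPU path; it is written purely for illustration and the constants are arbitrary.

#include <cassert>

int main() {
  const int outer_dim = 2, scalar_dim = 3, inner_dim = 4;
  for (int n = 0; n < outer_dim; ++n) {
    for (int d = 0; d < scalar_dim; ++d) {
      for (int i = 0; i < inner_dim; ++i) {
        const int index = (n * scalar_dim + d) * inner_dim + i;
        // index / inner_dim == n * scalar_dim + d, so the modulo recovers d.
        assert((index / inner_dim) % scalar_dim == d);
      }
    }
  }
  return 0;
}
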
20 changes: 19 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -301,7 +301,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 139 (last added: tile_param)
// LayerParameter next available layer-specific ID: 140 (last added: scalar_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -377,6 +377,7 @@ message LayerParameter {
optional ReductionParameter reduction_param = 136;
optional ReLUParameter relu_param = 123;
optional ReshapeParameter reshape_param = 133;
optional ScalarParameter scalar_param = 139;
optional SigmoidParameter sigmoid_param = 124;
optional SoftmaxParameter softmax_param = 125;
optional SPPParameter spp_param = 132;
@@ -876,6 +877,23 @@ message ReshapeParameter {
optional int32 num_axes = 3 [default = -1];
}

message ScalarParameter {
// The first axis of bottom[0] (the first input Blob) along which to apply
// bottom[1] (the second input Blob). May be negative to index from the end
// (e.g., -1 for the last axis).
//
// For example, if bottom[0] is 4D with shape 100x3x224x224, the output
// top[0] will have the same shape, and bottom[1] may have any of the
// following shapes (for the given value of axis):
// (axis == 0 == -4) 100; 100x3; 100x3x224; 100x3x224x224
// (axis == 1 == -3) 3; 3x224; 3x224x224
// (axis == 2 == -2) 224; 224x224
// (axis == 3 == -1) 224
// Furthermore, bottom[1] may have the empty shape (regardless of the value of
// "axis") -- a literal scalar.
optional int32 axis = 1 [default = 0];
}

message SigmoidParameter {
enum Engine {
DEFAULT = 0;
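
As a concrete instance of the axis rules in the ScalarParameter comment above, the 100x3x224x224 example with axis == 1 and a bottom[1] of shape 3x224 splits into the outer, scalar, and inner counts that ScalarLayer::Reshape computes. The snippet below is illustrative only; it uses plain ints in place of Blob shape queries.

#include <cassert>

int main() {
  // bottom[0]: 100 x 3 x 224 x 224, bottom[1]: 3 x 224, axis == 1.
  const int outer_dim  = 100;       // product of bottom[0] axes before axis
  const int scalar_dim = 3 * 224;   // product of all bottom[1] axes
  const int inner_dim  = 224;       // product of bottom[0] axes past bottom[1]
  // Together the three factors cover every element of bottom[0] and top[0].
  assert(outer_dim * scalar_dim * inner_dim == 100 * 3 * 224 * 224);
  return 0;
}
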
(Diff for the fifth changed file did not load and is not shown.)
