diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp index 62d6df71a4a..da2c4be1e61 100644 --- a/include/caffe/loss_layers.hpp +++ b/include/caffe/loss_layers.hpp @@ -78,6 +78,7 @@ class AccuracyLayer : public Layer { } } + int label_axis_, outer_num_, inner_num_; int top_k_; }; diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp index 186f9f8632c..839f0e58ed4 100644 --- a/src/caffe/layers/accuracy_layer.cpp +++ b/src/caffe/layers/accuracy_layer.cpp @@ -21,11 +21,15 @@ void AccuracyLayer::Reshape( const vector*>& bottom, const vector*>& top) { CHECK_LE(top_k_, bottom[0]->count() / bottom[1]->count()) << "top_k must be less than or equal to the number of classes."; - CHECK_GE(bottom[0]->num_axes(), bottom[1]->num_axes()); - for (int i = 0; i < bottom[1]->num_axes(); ++i) { - CHECK_LE(bottom[0]->shape(i), bottom[1]->shape(i)) - << "Dimension mismatch between predictions and label."; - } + label_axis_ = + bottom[0]->CanonicalAxisIndex(this->layer_param_.accuracy_param().axis()); + outer_num_ = bottom[0]->count(0, label_axis_); + inner_num_ = bottom[0]->count(label_axis_ + 1); + CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count()) + << "Number of labels must match number of predictions; " + << "e.g., if label axis == 1 and prediction shape is (N, C, H, W), " + << "label count (number of labels) must be N*H*W, " + << "with integer values in {0, 1, ..., C-1}."; vector top_shape(0); // Accuracy is a scalar; 0 axes. top[0]->Reshape(top_shape); } @@ -35,32 +39,35 @@ void AccuracyLayer::Forward_cpu(const vector*>& bottom, const vector*>& top) { Dtype accuracy = 0; const Dtype* bottom_data = bottom[0]->cpu_data(); - const Dtype* bottom_label = bottom[1]->cpu_data(); - int num = bottom[0]->count(0, bottom[1]->num_axes()); - int dim = bottom[0]->count() / num; + const Dtype* label = bottom[1]->cpu_data(); + const int dim = bottom[0]->count() / outer_num_; + const int num_labels = bottom[0]->shape(label_axis_); vector maxval(top_k_+1); vector max_id(top_k_+1); - for (int i = 0; i < num; ++i) { - // Top-k accuracy - std::vector > bottom_data_vector; - for (int j = 0; j < dim; ++j) { - bottom_data_vector.push_back( - std::make_pair(bottom_data[i * dim + j], j)); - } - std::partial_sort( - bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, - bottom_data_vector.end(), std::greater >()); - // check if true label is in top k predictions - for (int k = 0; k < top_k_; k++) { - if (bottom_data_vector[k].second == static_cast(bottom_label[i])) { - ++accuracy; - break; + for (int i = 0; i < outer_num_; ++i) { + for (int j = 0; j < inner_num_; ++j) { + // Top-k accuracy + std::vector > bottom_data_vector; + for (int k = 0; k < num_labels; ++k) { + bottom_data_vector.push_back(std::make_pair( + bottom_data[i * dim + k * inner_num_ + j], k)); + } + std::partial_sort( + bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, + bottom_data_vector.end(), std::greater >()); + // check if true label is in top k predictions + const int label_value = static_cast(label[i * inner_num_ + j]); + for (int k = 0; k < top_k_; k++) { + if (bottom_data_vector[k].second == label_value) { + ++accuracy; + break; + } } } } // LOG(INFO) << "Accuracy: " << accuracy; - top[0]->mutable_cpu_data()[0] = accuracy / num; + top[0]->mutable_cpu_data()[0] = accuracy / outer_num_ / inner_num_; // Accuracy layer should not be used as a loss function. } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 3b4794664b5..7792695abfa 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -367,6 +367,13 @@ message AccuracyParameter { // the top k scoring classes. By default, only compare to the top scoring // class (i.e. argmax). optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; } // Message that stores parameters used by ArgMaxLayer