From 4d988ed28ec26702fcd555f42aa336dbecda6423 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 12 Sep 2017 09:45:15 +0800 Subject: [PATCH 1/7] add auc_op --- paddle/operators/auc_op.cc | 80 ++++++++++++++++++++++ paddle/operators/auc_op.h | 132 +++++++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 paddle/operators/auc_op.cc create mode 100644 paddle/operators/auc_op.h diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc new file mode 100644 index 0000000000000..fa18d6ca0d220 --- /dev/null +++ b/paddle/operators/auc_op.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/auc_op.h" + +namespace paddle { +namespace operators { + +class AccuracyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), + "Input of Inference must be initialized."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input of Inference must be initialized."); + auto *inference = ctx.Input("Inference"); + auto *inference_prob = ctx.Input("InferenceProb"); + auto *label = ctx.Input("Label"); + + PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); + PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], + "inference size must be the same as label size"); + PADDLE_ENFORCE_EQ(inference->dims(), inference_prob->dims()); + + ctx.Output("Accuracy")->Resize({1}); + } +}; + +class AucOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Inference", + "Topk(indices) the network output, float value indicating " + "probabilities of classification"); + AddInput("InferenceProb", + "Topk(values) the network output, float value indicating " + "probabilities of classification"); + AddInput("Label", "Label of the training data"); + // TODO(typhoonzero): support weight + AddOutput("AUC", "Area Under Curve caculations"); + AddAttr("curve", "Possible curves are ROC and PR") + .SetDefault("ROC"); + AddAttr("num_thresholds", + "The number of thresholds to use when discretizing the" + " roc curve.") + .SetDefault(200); + + AddComment( + R"DOC(Computes the AUC according forward output and label. + You can find the definations here: + https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve + + Possible curves are: + ROC: Receiver operating characteristic + PR: Precision Recall + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AccuracyOp, ops::AccuracyOpMaker); +REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel); diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h new file mode 100644 index 0000000000000..d4f40cd79c664 --- /dev/null +++ b/paddle/operators/auc_op.h @@ -0,0 +1,132 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class AccuracyKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* inference = ctx.Input("Inference"); + auto* inference_prob = ctx.Input("InferenceProb"); + auto* label = ctx.Input("Label"); + auto* auc = ctx.Output("AUC"); + + float* auc_data = auc->mutable_data(ctx.GetPlace()); + + std::string curve = ctx.Attr("curve"); + int num_thresholds = ctx.Attr("num_thresholds"); + std::vector thresholds_list; + thresholds_list.reserve(num_thresholds); + for (int i = 1; i < num_thresholds - 1; i++) { + thresholds_list[i] = (float)i / (num_thresholds - 1); + } + const float kEpsilon = 1e-7; + thresholds_list[0] = 0.0f - kEpsilon; + thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; + + const int* inference_data = inference->data(); + const T* inference_prob_data = inference->data(); + const T* label_data = label->data(); + + size_t num_samples = inference->dims()[0]; + size_t class_dim = inference->dims()[1]; + + // create local tensor for storing the curve: TP, FN, TN, FP + // TODO(typhoonzero): put these tensors in Scope + // TODO(typhoonzero): use op to caculate these values. + Tensor true_positive, false_positeve, true_negative, false_negative; + + true_positive.Resize({num_thresholds}); + false_negative.Resize({num_thresholds}); + true_negative.Resize({num_thresholds}); + false_positive.Resize({num_thresholds}); + + int* tp_data = true_positive.mutable_data(); + int* fn_data = false_negative.mutable_data(); + int* tn_data = true_negative.mutable_data(); + int* fp_data = false_positive.mutable_data(); + + for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end(); + thresh++) { + size_t idx_thresh = thresh - thresholds_list.begin(); + // caculate TP, FN, TN, FP for current thresh + int tp, fn, tn, fp = 0; + for (size_t i = 0; i < num_samples; i++) { + for (size_t j = 0; j < class_dim; j++) { + if (inference_data[i * class_dim + j] == label_data[i]) { + if (inference_prob_data[i * class_dim + j] >= (*thresh)) { + tp++; + } else { + tn++; + } + } else { + if (inference_prob_data[i * class_dim + j] >= (*thresh)) { + fp++; + } else { + fn++; + } + } + } + } + // store rates + tp_data[idx_thresh] = tp; + fn_data[idx_thresh] = fn; + tn_data[idx_thresh] = tn; + fp_data[idx_thresh] = fp; + } + // epsilon to avoid divide by zero. + float epsilon = 1e-6; + // Riemann sum to caculate auc. + Tensor tp_rate, fp_rate, rec_rate; + tp_rate.Resize({num_thresholds}); + fp_rate.Resize({num_thresholds}); + rec_rate.Resize({num_thresholds}); + float* tp_rate_data = tp_rate.mutable_data(); + float* fp_rate_data = fp_rate.mutable_data(); + float* rec_rate_data = rec_rate.mutable_data(); + for (int i = 0; i < num_thresholds; i++) { + tp_rate_data[i] = ((float)tp_data[i + epsilon) / (tp_data[i] + fn_data[i] + epsilon); + fp_rate_data[i] = + (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); + rec_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); + } + + if (curve == "ROC") { + for (int i = 1; i < num_thresholds; i++) { + auto dx = fp_rate_data[i] - fp_rate_data[i - 1]; + auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f; + *auc_data = *auc_data + dx * y; + } + } else if (curve = "PR") { + for (int i = 1; i < num_thresholds; i++) { + auto dx = tp_rate_data[i] - tp_rate_data[i - 1]; + auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f; + *auc_data = *auc_data + dx * y; + } + } + } +}; + +} // namespace operators +} // namespace paddle From d1e6d5522a437ae592e8a2e2126e6ff50d9c7d08 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 12 Sep 2017 21:03:55 +0800 Subject: [PATCH 2/7] update --- paddle/operators/auc_op.cc | 4 ++-- paddle/operators/auc_op.h | 32 ++++++++++++++++---------------- paddle/pybind/pybind.cc | 1 + 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index fa18d6ca0d220..3a43f9bcc48df 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { -class AccuracyOp : public framework::OperatorWithKernel { +class AucOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -76,5 +76,5 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AccuracyOp, ops::AccuracyOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker); REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel); diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index d4f40cd79c664..fd110c06e64f9 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -23,7 +23,7 @@ namespace operators { using Tensor = framework::Tensor; template -class AccuracyKernel : public framework::OpKernel { +class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); @@ -45,7 +45,7 @@ class AccuracyKernel : public framework::OpKernel { thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; const int* inference_data = inference->data(); - const T* inference_prob_data = inference->data(); + const T* inference_prob_data = inference_prob->data(); const T* label_data = label->data(); size_t num_samples = inference->dims()[0]; @@ -54,17 +54,17 @@ class AccuracyKernel : public framework::OpKernel { // create local tensor for storing the curve: TP, FN, TN, FP // TODO(typhoonzero): put these tensors in Scope // TODO(typhoonzero): use op to caculate these values. - Tensor true_positive, false_positeve, true_negative, false_negative; + Tensor true_positive, false_positive, true_negative, false_negative; true_positive.Resize({num_thresholds}); false_negative.Resize({num_thresholds}); true_negative.Resize({num_thresholds}); false_positive.Resize({num_thresholds}); - int* tp_data = true_positive.mutable_data(); - int* fn_data = false_negative.mutable_data(); - int* tn_data = true_negative.mutable_data(); - int* fp_data = false_positive.mutable_data(); + int* tp_data = true_positive.mutable_data(ctx.GetPlace()); + int* fn_data = false_negative.mutable_data(ctx.GetPlace()); + int* tn_data = true_negative.mutable_data(ctx.GetPlace()); + int* fp_data = false_positive.mutable_data(ctx.GetPlace()); for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end(); thresh++) { @@ -101,15 +101,15 @@ class AccuracyKernel : public framework::OpKernel { tp_rate.Resize({num_thresholds}); fp_rate.Resize({num_thresholds}); rec_rate.Resize({num_thresholds}); - float* tp_rate_data = tp_rate.mutable_data(); - float* fp_rate_data = fp_rate.mutable_data(); - float* rec_rate_data = rec_rate.mutable_data(); + float* tp_rate_data = tp_rate.mutable_data(ctx.GetPlace()); + float* fp_rate_data = fp_rate.mutable_data(ctx.GetPlace()); + float* rec_rate_data = rec_rate.mutable_data(ctx.GetPlace()); for (int i = 0; i < num_thresholds; i++) { - tp_rate_data[i] = ((float)tp_data[i + epsilon) / (tp_data[i] + fn_data[i] + epsilon); - fp_rate_data[i] = - (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); - rec_rate_data[i] = - ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); + tp_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon); + fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); + rec_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); } if (curve == "ROC") { @@ -118,7 +118,7 @@ class AccuracyKernel : public framework::OpKernel { auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f; *auc_data = *auc_data + dx * y; } - } else if (curve = "PR") { + } else if (curve == "PR") { for (int i = 1; i < num_thresholds; i++) { auto dx = tp_rate_data[i] - tp_rate_data[i - 1]; auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f; diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 53985933ed143..a673b7d1a87ed 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -50,6 +50,7 @@ USE_OP(cos_sim); USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); USE_OP(top_k); +USE_CPU_ONLY_OP(auc); USE_OP(squared_l2_distance); namespace paddle { From 399a5eec69a34d6336858179080ae3e5dc67ee90 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 13 Sep 2017 12:45:23 +0800 Subject: [PATCH 3/7] auc_op --- paddle/operators/auc_op.cc | 34 ++++++++++++++-------------- paddle/operators/auc_op.h | 45 ++++++++++++++++++++++---------------- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index 3a43f9bcc48df..63f0d50fdc3ef 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -28,15 +28,12 @@ class AucOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), "Input of Inference must be initialized."); auto *inference = ctx.Input("Inference"); - auto *inference_prob = ctx.Input("InferenceProb"); auto *label = ctx.Input("Label"); - PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); - PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], - "inference size must be the same as label size"); - PADDLE_ENFORCE_EQ(inference->dims(), inference_prob->dims()); + PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), + "inference should have same shape as label"); - ctx.Output("Accuracy")->Resize({1}); + ctx.Output("AUC")->Resize({1}); } }; @@ -45,14 +42,15 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Inference", - "Topk(indices) the network output, float value indicating " - "probabilities of classification"); - AddInput("InferenceProb", - "Topk(values) the network output, float value indicating " - "probabilities of classification"); - AddInput("Label", "Label of the training data"); - // TODO(typhoonzero): support weight - AddOutput("AUC", "Area Under Curve caculations"); + "A floating point `Tensor` of arbitrary shape and whose values" + "are in the range `[0, 1]`."); + AddInput("Label", + "A `Tensor` whose shape matches " + "`Inference`. Will be cast to `bool`."); + // TODO(typhoonzero): support weight input + AddOutput("AUC", + "A scalar `Tensor` representing the " + "current area-under-curve."); AddAttr("curve", "Possible curves are ROC and PR") .SetDefault("ROC"); AddAttr("num_thresholds", @@ -62,12 +60,16 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddComment( R"DOC(Computes the AUC according forward output and label. + Best to use for binary classification evaluations. + If `label` can be values other than 0 and 1, it will be cast + to bool. + You can find the definations here: https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve Possible curves are: - ROC: Receiver operating characteristic - PR: Precision Recall + - ROC: Receiver operating characteristic + - PR: Precision Recall )DOC"); } }; diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index fd110c06e64f9..b6ca74f1af270 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -22,12 +22,15 @@ namespace operators { using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + template class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); - auto* inference_prob = ctx.Input("InferenceProb"); auto* label = ctx.Input("Label"); auto* auc = ctx.Output("AUC"); @@ -44,14 +47,20 @@ class AucKernel : public framework::OpKernel { thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; - const int* inference_data = inference->data(); - const T* inference_prob_data = inference_prob->data(); - const T* label_data = label->data(); + size_t num_samples = inference->numel(); + + const T* inference_data = inference->data(); + Tensor label_casted; + label_casted.Resize(label->dims()); + bool* label_casted_data = label_casted.mutable_data(ctx.GetPlace()); - size_t num_samples = inference->dims()[0]; - size_t class_dim = inference->dims()[1]; + const int* label_data = label->data(); + // cast label_data to bool + for (size_t i = 0; i < num_samples; i++) { + label_casted_data[i] = static_cast(label_data[i]); + } - // create local tensor for storing the curve: TP, FN, TN, FP + // Create local tensor for storing the curve: TP, FN, TN, FP // TODO(typhoonzero): put these tensors in Scope // TODO(typhoonzero): use op to caculate these values. Tensor true_positive, false_positive, true_negative, false_negative; @@ -72,19 +81,17 @@ class AucKernel : public framework::OpKernel { // caculate TP, FN, TN, FP for current thresh int tp, fn, tn, fp = 0; for (size_t i = 0; i < num_samples; i++) { - for (size_t j = 0; j < class_dim; j++) { - if (inference_data[i * class_dim + j] == label_data[i]) { - if (inference_prob_data[i * class_dim + j] >= (*thresh)) { - tp++; - } else { - tn++; - } + if (label_casted_data[i]) { + if (inference_data[i] >= (*thresh)) { + tp++; + } else { + tn++; + } + } else { + if (inference_data[i] >= (*thresh)) { + fp++; } else { - if (inference_prob_data[i * class_dim + j] >= (*thresh)) { - fp++; - } else { - fn++; - } + fn++; } } } From c7eef34c28353dc74a0042dcd2b35cb2d40598d5 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 13 Sep 2017 16:49:19 +0800 Subject: [PATCH 4/7] auc cpu only --- paddle/operators/auc_op.cc | 5 +- paddle/operators/auc_op.h | 24 ++++--- .../paddle/v2/framework/tests/test_auc_op.py | 66 +++++++++++++++++++ .../v2/framework/tests/test_top_k_op.py | 6 ++ 4 files changed, 86 insertions(+), 15 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_auc_op.py diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index 63f0d50fdc3ef..f88f722d6c2b1 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -31,9 +31,9 @@ class AucOp : public framework::OperatorWithKernel { auto *label = ctx.Input("Label"); PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), - "inference should have same shape as label"); + "inference and label should have same shape"); - ctx.Output("AUC")->Resize({1}); + ctx.Output("AUC")->Resize({1}); } }; @@ -51,6 +51,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("AUC", "A scalar `Tensor` representing the " "current area-under-curve."); + AddAttr("curve", "Possible curves are ROC and PR") .SetDefault("ROC"); AddAttr("num_thresholds", diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index b6ca74f1af270..ad5585be30481 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include +#include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" @@ -75,23 +75,21 @@ class AucKernel : public framework::OpKernel { int* tn_data = true_negative.mutable_data(ctx.GetPlace()); int* fp_data = false_positive.mutable_data(ctx.GetPlace()); - for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end(); - thresh++) { - size_t idx_thresh = thresh - thresholds_list.begin(); + for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { // caculate TP, FN, TN, FP for current thresh - int tp, fn, tn, fp = 0; + int tp = 0, fn = 0, tn = 0, fp = 0; for (size_t i = 0; i < num_samples; i++) { if (label_casted_data[i]) { - if (inference_data[i] >= (*thresh)) { + if (inference_data[i] >= (thresholds_list[idx_thresh])) { tp++; } else { - tn++; + fn++; } } else { - if (inference_data[i] >= (*thresh)) { + if (inference_data[i] >= (thresholds_list[idx_thresh])) { fp++; } else { - fn++; + tn++; } } } @@ -118,11 +116,11 @@ class AucKernel : public framework::OpKernel { rec_rate_data[i] = ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); } - + *auc_data = 0.0f; if (curve == "ROC") { - for (int i = 1; i < num_thresholds; i++) { - auto dx = fp_rate_data[i] - fp_rate_data[i - 1]; - auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f; + for (int i = 0; i < num_thresholds - 1; i++) { + auto dx = fp_rate_data[i] - fp_rate_data[i + 1]; + auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f; *auc_data = *auc_data + dx * y; } } else if (curve == "PR") { diff --git a/python/paddle/v2/framework/tests/test_auc_op.py b/python/paddle/v2/framework/tests/test_auc_op.py new file mode 100644 index 0000000000000..f458e01fc567d --- /dev/null +++ b/python/paddle/v2/framework/tests/test_auc_op.py @@ -0,0 +1,66 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAucOp(OpTest): + def setUp(self): + self.op_type = "auc" + pred = np.random.random((128)).astype("float32") + labels = np.random.randint(0, 2, (128, )) + num_thresholds = 200 + self.inputs = {'Inference': pred, 'Label': labels} + self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} + # NOTE: sklearn use a different way to generate thresholds + # which will cause the result differs slightly: + # from sklearn.metrics import roc_curve, auc + # fpr, tpr, thresholds = roc_curve(labels, pred) + # auc_value = auc(fpr, tpr) + # we caculate AUC again using numpy for testing + kepsilon = 1e-7 # to account for floating point imprecisions + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] + + # caculate TP, FN, TN, FP count + tp_list = np.ndarray((num_thresholds, )) + fn_list = np.ndarray((num_thresholds, )) + tn_list = np.ndarray((num_thresholds, )) + fp_list = np.ndarray((num_thresholds, )) + for idx_thresh, thresh in enumerate(thresholds): + tp, fn, tn, fp = 0, 0, 0, 0 + for i, lbl in enumerate(labels): + if lbl: + if pred[i] >= thresh: + tp += 1 + else: + fn += 1 + else: + if pred[i] >= thresh: + fp += 1 + else: + tn += 1 + tp_list[idx_thresh] = tp + fn_list[idx_thresh] = fn + tn_list[idx_thresh] = tn + fp_list[idx_thresh] = fp + + epsilon = 1e-6 + tpr = (tp_list.astype("float32") + epsilon) / ( + tp_list + fn_list + epsilon) + fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon) + rec = (tp_list.astype("float32") + epsilon) / ( + tp_list + fp_list + epsilon) + + x = fpr[:num_thresholds - 1] - fpr[1:] + y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0 + auc_value = np.sum(x * y) + + self.outputs = {'AUC': auc_value} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_top_k_op.py b/python/paddle/v2/framework/tests/test_top_k_op.py index cab799256d791..694f37d612d4c 100644 --- a/python/paddle/v2/framework/tests/test_top_k_op.py +++ b/python/paddle/v2/framework/tests/test_top_k_op.py @@ -21,6 +21,9 @@ def setUp(self): self.outputs = {'Out': output, 'Indices': indices} + def test_check_output(self): + self.check_output() + class TestTopkOp3d(OpTest): def setUp(self): @@ -42,6 +45,9 @@ def setUp(self): self.outputs = {'Out': output, 'Indices': indices} + def test_check_output(self): + self.check_output() + if __name__ == "__main__": unittest.main() From bf7bc1276fef28d5504c862982f86470cf87ea93 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 19 Sep 2017 20:50:38 +0800 Subject: [PATCH 5/7] update --- paddle/operators/auc_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index f88f722d6c2b1..89f379b78f9e9 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -33,7 +33,7 @@ class AucOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), "inference and label should have same shape"); - ctx.Output("AUC")->Resize({1}); + ctx.Output("AUC")->Resize({1}); } }; From 436b6acc6ffedb29bd84e4b5d8f7c332760ac1f2 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 20 Sep 2017 16:09:48 +0800 Subject: [PATCH 6/7] follow comments --- paddle/operators/auc_op.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index 89f379b78f9e9..e7275a5933c03 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -42,17 +42,17 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Inference", - "A floating point `Tensor` of arbitrary shape and whose values" - "are in the range `[0, 1]`."); + "A floating point tensor of arbitrary shape and whose values" + "are in the range [0, 1]."); AddInput("Label", - "A `Tensor` whose shape matches " - "`Inference`. Will be cast to `bool`."); + "A tensor whose shape matches " + "Inference. Will be cast to bool."); // TODO(typhoonzero): support weight input AddOutput("AUC", - "A scalar `Tensor` representing the " + "A scalar representing the " "current area-under-curve."); - AddAttr("curve", "Possible curves are ROC and PR") + AddAttr("curve", "Curve type, can be 'ROC' or 'PR'.") .SetDefault("ROC"); AddAttr("num_thresholds", "The number of thresholds to use when discretizing the" @@ -62,7 +62,8 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddComment( R"DOC(Computes the AUC according forward output and label. Best to use for binary classification evaluations. - If `label` can be values other than 0 and 1, it will be cast + + If input label contains values other than 0 and 1, it will be cast to bool. You can find the definations here: From 63309941b3f13d56afb863bf7c257ee284857028 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 9 Oct 2017 17:51:17 +0800 Subject: [PATCH 7/7] pull develop and update --- paddle/operators/auc_op.cc | 21 +++++++++++---------- paddle/operators/auc_op.h | 6 ++---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index e7275a5933c03..d8cecf0957c6c 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), - "Input of Inference must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input of Inference must be initialized."); - auto *inference = ctx.Input("Inference"); - auto *label = ctx.Input("Label"); - - PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Inference"), + "Input of Inference must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input of Label must be initialized."); + auto inference_dim = ctx->GetInputDim("Inference"); + auto label_dim = ctx->GetInputDim("Label"); + + PADDLE_ENFORCE_EQ(inference_dim, label_dim, "inference and label should have same shape"); - ctx.Output("AUC")->Resize({1}); + ctx->SetOutputDim("AUC", {1}); + ctx->ShareLoD("Inference", /*->*/ "AUC"); } }; diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index ad5585be30481..be6ef29d5f6cf 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" @@ -27,7 +26,7 @@ template ; template -class AucKernel : public framework::OpKernel { +class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); @@ -61,8 +60,7 @@ class AucKernel : public framework::OpKernel { } // Create local tensor for storing the curve: TP, FN, TN, FP - // TODO(typhoonzero): put these tensors in Scope - // TODO(typhoonzero): use op to caculate these values. + // TODO(typhoonzero): use eigen op to caculate these values. Tensor true_positive, false_positive, true_negative, false_negative; true_positive.Resize({num_thresholds});