-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4063 from typhoonzero/auc_op
Auc op
- Loading branch information
Showing
3 changed files
with
286 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/auc_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class AucOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(framework::InferShapeContextBase *ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasInput("Inference"), | ||
"Input of Inference must be initialized."); | ||
PADDLE_ENFORCE(ctx->HasInput("Label"), | ||
"Input of Label must be initialized."); | ||
auto inference_dim = ctx->GetInputDim("Inference"); | ||
auto label_dim = ctx->GetInputDim("Label"); | ||
|
||
PADDLE_ENFORCE_EQ(inference_dim, label_dim, | ||
"inference and label should have same shape"); | ||
|
||
ctx->SetOutputDim("AUC", {1}); | ||
ctx->ShareLoD("Inference", /*->*/ "AUC"); | ||
} | ||
}; | ||
|
||
// Declares the proto of the "auc" operator: two inputs ("Inference", "Label"),
// one output ("AUC") and the attributes "curve" and "num_thresholds".
class AucOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Adjacent string literals are concatenated verbatim, so every fragment
    // that continues a sentence must carry its own separating space.
    AddInput("Inference",
             "A floating point tensor of arbitrary shape and whose values "
             "are in the range [0, 1].");
    AddInput("Label",
             "A tensor whose shape matches "
             "Inference. Will be cast to bool.");
    // TODO(typhoonzero): support weight input
    AddOutput("AUC",
              "A scalar representing the "
              "current area-under-curve.");

    AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
        .SetDefault("ROC");
    AddAttr<int>("num_thresholds",
                 "The number of thresholds to use when discretizing the"
                 " roc curve.")
        .SetDefault(200);

    AddComment(
        R"DOC(Computes the AUC according to forward output and label.
Best to use for binary classification evaluations.
If input label contains values other than 0 and 1, it will be cast
to bool.
You can find the definitions here:
https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
Possible curves are:
- ROC: Receiver operating characteristic
- PR: Precision Recall
)DOC");
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
// Short alias for the operator namespace, used by the registrations below.
namespace ops = paddle::operators; | ||
// Register the "auc" operator with its proto maker through the
// without-gradient registration macro.
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker); | ||
// Register the kernel for "auc", instantiated for CPUPlace and float.
REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel<paddle::platform::CPUPlace, float>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
// Unqualified name for framework::Tensor inside paddle::operators.
using Tensor = framework::Tensor; | ||
|
||
// Alias template forwarding to framework::EigenVector, defaulting to
// row-major layout and Eigen::DenseIndex as the index type.
// NOTE(review): not referenced by the AucKernel code visible in this header.
template <typename T, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; | ||
|
||
template <typename Place, typename T> | ||
class AucKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto* inference = ctx.Input<Tensor>("Inference"); | ||
auto* label = ctx.Input<Tensor>("Label"); | ||
auto* auc = ctx.Output<Tensor>("AUC"); | ||
|
||
float* auc_data = auc->mutable_data<float>(ctx.GetPlace()); | ||
|
||
std::string curve = ctx.Attr<std::string>("curve"); | ||
int num_thresholds = ctx.Attr<int>("num_thresholds"); | ||
std::vector<float> thresholds_list; | ||
thresholds_list.reserve(num_thresholds); | ||
for (int i = 1; i < num_thresholds - 1; i++) { | ||
thresholds_list[i] = (float)i / (num_thresholds - 1); | ||
} | ||
const float kEpsilon = 1e-7; | ||
thresholds_list[0] = 0.0f - kEpsilon; | ||
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; | ||
|
||
size_t num_samples = inference->numel(); | ||
|
||
const T* inference_data = inference->data<T>(); | ||
Tensor label_casted; | ||
label_casted.Resize(label->dims()); | ||
bool* label_casted_data = label_casted.mutable_data<bool>(ctx.GetPlace()); | ||
|
||
const int* label_data = label->data<int>(); | ||
// cast label_data to bool | ||
for (size_t i = 0; i < num_samples; i++) { | ||
label_casted_data[i] = static_cast<bool>(label_data[i]); | ||
} | ||
|
||
// Create local tensor for storing the curve: TP, FN, TN, FP | ||
// TODO(typhoonzero): use eigen op to caculate these values. | ||
Tensor true_positive, false_positive, true_negative, false_negative; | ||
|
||
true_positive.Resize({num_thresholds}); | ||
false_negative.Resize({num_thresholds}); | ||
true_negative.Resize({num_thresholds}); | ||
false_positive.Resize({num_thresholds}); | ||
|
||
int* tp_data = true_positive.mutable_data<int>(ctx.GetPlace()); | ||
int* fn_data = false_negative.mutable_data<int>(ctx.GetPlace()); | ||
int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace()); | ||
int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace()); | ||
|
||
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { | ||
// caculate TP, FN, TN, FP for current thresh | ||
int tp = 0, fn = 0, tn = 0, fp = 0; | ||
for (size_t i = 0; i < num_samples; i++) { | ||
if (label_casted_data[i]) { | ||
if (inference_data[i] >= (thresholds_list[idx_thresh])) { | ||
tp++; | ||
} else { | ||
fn++; | ||
} | ||
} else { | ||
if (inference_data[i] >= (thresholds_list[idx_thresh])) { | ||
fp++; | ||
} else { | ||
tn++; | ||
} | ||
} | ||
} | ||
// store rates | ||
tp_data[idx_thresh] = tp; | ||
fn_data[idx_thresh] = fn; | ||
tn_data[idx_thresh] = tn; | ||
fp_data[idx_thresh] = fp; | ||
} | ||
// epsilon to avoid divide by zero. | ||
float epsilon = 1e-6; | ||
// Riemann sum to caculate auc. | ||
Tensor tp_rate, fp_rate, rec_rate; | ||
tp_rate.Resize({num_thresholds}); | ||
fp_rate.Resize({num_thresholds}); | ||
rec_rate.Resize({num_thresholds}); | ||
float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace()); | ||
float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace()); | ||
float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace()); | ||
for (int i = 0; i < num_thresholds; i++) { | ||
tp_rate_data[i] = | ||
((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon); | ||
fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); | ||
rec_rate_data[i] = | ||
((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); | ||
} | ||
*auc_data = 0.0f; | ||
if (curve == "ROC") { | ||
for (int i = 0; i < num_thresholds - 1; i++) { | ||
auto dx = fp_rate_data[i] - fp_rate_data[i + 1]; | ||
auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f; | ||
*auc_data = *auc_data + dx * y; | ||
} | ||
} else if (curve == "PR") { | ||
for (int i = 1; i < num_thresholds; i++) { | ||
auto dx = tp_rate_data[i] - tp_rate_data[i - 1]; | ||
auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f; | ||
*auc_data = *auc_data + dx * y; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import unittest | ||
import numpy as np | ||
from op_test import OpTest | ||
|
||
|
||
class TestAucOp(OpTest):
    """Checks the ROC output of the `auc` operator against a NumPy reference."""

    def setUp(self):
        """Builds random inputs and the expected AUC for the default ROC curve."""
        self.op_type = "auc"
        pred = np.random.random((128)).astype("float32")
        labels = np.random.randint(0, 2, (128, ))
        num_thresholds = 200
        self.inputs = {'Inference': pred, 'Label': labels}
        self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
        # NOTE: sklearn uses a different way to generate thresholds
        #   which causes the result to differ slightly:
        # from sklearn.metrics import roc_curve, auc
        # fpr, tpr, thresholds = roc_curve(labels, pred)
        # auc_value = auc(fpr, tpr)
        # so the AUC is calculated again with numpy for testing.
        kepsilon = 1e-7  # to account for floating point imprecisions
        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                      for i in range(num_thresholds - 2)]
        thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

        # Calculate the TP, FN, TN, FP counts for every threshold. The
        # buffers are zero-initialised rather than left as raw memory.
        tp_list = np.zeros((num_thresholds, ))
        fn_list = np.zeros((num_thresholds, ))
        tn_list = np.zeros((num_thresholds, ))
        fp_list = np.zeros((num_thresholds, ))
        for idx_thresh, thresh in enumerate(thresholds):
            tp, fn, tn, fp = 0, 0, 0, 0
            for i, lbl in enumerate(labels):
                if lbl:
                    if pred[i] >= thresh:
                        tp += 1
                    else:
                        fn += 1
                else:
                    if pred[i] >= thresh:
                        fp += 1
                    else:
                        tn += 1
            tp_list[idx_thresh] = tp
            fn_list[idx_thresh] = fn
            tn_list[idx_thresh] = tn
            fp_list[idx_thresh] = fp

        # epsilon keeps the denominators non-zero.
        epsilon = 1e-6
        tpr = (tp_list.astype("float32") + epsilon) / (
            tp_list + fn_list + epsilon)
        fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)

        # Trapezoidal sum over consecutive thresholds: width is the drop in
        # false positive rate, height the mean true positive rate.
        x = fpr[:num_thresholds - 1] - fpr[1:]
        y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
        auc_value = np.sum(x * y)

        self.outputs = {'AUC': auc_value}

    def test_check_output(self):
        self.check_output()
|
||
|
||
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__": | ||
unittest.main() |