Skip to content

Commit

Permalink
pull develop and update
Browse files Browse the repository at this point in the history
  • Loading branch information
typhoonzero committed Oct 9, 2017
1 parent 2824352 commit 6330994
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
21 changes: 11 additions & 10 deletions paddle/operators/auc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"),
"Input of Inference must be initialized.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input of Inference must be initialized.");
auto *inference = ctx.Input<framework::Tensor>("Inference");
auto *label = ctx.Input<framework::Tensor>("Label");

PADDLE_ENFORCE_EQ(inference->dims(), label->dims(),
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input of Inference must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label must be initialized.");
auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label");

PADDLE_ENFORCE_EQ(inference_dim, label_dim,
"inference and label should have same shape");

ctx.Output<framework::LoDTensor>("AUC")->Resize({1});
ctx->SetOutputDim("AUC", {1});
ctx->ShareLoD("Inference", /*->*/ "AUC");
}
};

Expand Down
6 changes: 2 additions & 4 deletions paddle/operators/auc_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

Expand All @@ -27,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class AucKernel : public framework::OpKernel {
class AucKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
Expand Down Expand Up @@ -61,8 +60,7 @@ class AucKernel : public framework::OpKernel {
}

// Create local tensor for storing the curve: TP, FN, TN, FP
// TODO(typhoonzero): put these tensors in Scope
// TODO(typhoonzero): use op to caculate these values.
// TODO(typhoonzero): use eigen op to caculate these values.
Tensor true_positive, false_positive, true_negative, false_negative;

true_positive.Resize({num_thresholds});
Expand Down

0 comments on commit 6330994

Please sign in to comment.