Skip to content

Commit

Permalink
add op type in check nan/inf (PaddlePaddle#15986)
Browse files Browse the repository at this point in the history
* add op name in check nan/inf, test=develop
  • Loading branch information
seiriosPlus authored and ceci3 committed Mar 4, 2019
1 parent 2bdf446 commit 7b0875e
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
const RuntimeContext& ctx_;
};

static void CheckTensorNANOrInf(const std::string& name,
static void CheckTensorNANOrInf(const std::string& op_type,
const std::string& name,
const framework::Tensor& tensor) {
if (tensor.memory_size() == 0) {
return;
Expand All @@ -892,9 +893,9 @@ static void CheckTensorNANOrInf(const std::string& name,
return;
}
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
"Tensor %s contains Inf", name);
"Operator %s output Tensor %s contains Inf", op_type, name);
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
"Tensor %s contains NAN", name);
"Operator %s output Tensor %s contains NAN", op_type, name);
}

void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
Expand Down Expand Up @@ -988,9 +989,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
CheckTensorNANOrInf(type_, vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
CheckTensorNANOrInf(type_, vname,
var->Get<framework::SelectedRows>().value());
}
}
}
Expand Down

0 comments on commit 7b0875e

Please sign in to comment.