Skip to content

Commit

Permalink
Merge pull request #14149 from chengduoZH/fix_sum_op_bug_relase
Browse files Browse the repository at this point in the history
Fix sum op's GetExpectedKernelType
  • Loading branch information
panyx0718 authored Oct 31, 2018
2 parents c5591f7 + 618d7e3 commit 66024e9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
9 changes: 4 additions & 5 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
}

static const Tensor* GetTensorFromVar(Variable* var) {
const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
return static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
} else if (var->IsType<SelectedRows>()) {
return var->GetMutable<SelectedRows>()->mutable_value();
return &(var->Get<SelectedRows>().value());
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
Expand Down Expand Up @@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
return var == nullptr ? nullptr
: GetTensorFromVar(const_cast<Variable*>(var));
return var == nullptr ? nullptr : GetTensorFromVar(var);
}

template <>
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
}

proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(const Variable* var);

class OperatorBase;
class ExecutionContext;
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/operators/sum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,15 @@ class SumOp : public framework::OperatorWithKernel {
if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1;
for (auto& x_var : x_vars) {
auto& lod_tensor = x_var->Get<framework::LoDTensor>();
if (lod_tensor.numel() == 0) {
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
auto tensor = framework::GetTensorFromVar(x_var);
if (tensor->numel() == 0) {
continue;
}
if (dtype == -1) {
dtype = framework::ToDataType(lod_tensor.type());
dtype = framework::ToDataType(tensor->type());
} else {
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type()));
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
}
}
PADDLE_ENFORCE_NE(dtype, -1,
Expand Down

0 comments on commit 66024e9

Please sign in to comment.