[cherry-pick] Optimize performance of dygraph (#42231, #42253) (#42309)
* Optimize the performance of sum api (#42231)

* optimize the performance of sum api

* optimize IsDenseTensorInput

* remove debug log

* Add move construct for KernelSignature (#42253)

* add move construct for KernelSignature

* add noexcept

* fix cherry-pick problem
zyfncg authored Apr 28, 2022
1 parent 9e1aa11 commit 69a92b7
Showing 19 changed files with 95 additions and 33 deletions.
12 changes: 7 additions & 5 deletions paddle/fluid/framework/infershape_utils.cc
@@ -70,6 +70,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
   }
 
   bool IsDenseTensorInput(const std::string& name) const override {
+    auto var_type = ctx_.GetInputVarType(name);
+    return var_type == proto::VarType::LOD_TENSOR;
+  }
+
+  bool IsDenseTensorInputs(const std::string& name) const override {
     auto var_types = ctx_.GetInputsVarType(name);
     return std::all_of(var_types.begin(), var_types.end(),
                        [](const proto::VarType::Type& type) {
@@ -78,11 +83,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
   }
 
   bool IsSelectedRowsInput(const std::string& name) const override {
-    auto var_types = ctx_.GetInputsVarType(name);
-    return std::all_of(var_types.begin(), var_types.end(),
-                       [](const proto::VarType::Type& type) {
-                         return type == proto::VarType::SELECTED_ROWS;
-                       });
+    auto var_type = ctx_.GetInputVarType(name);
+    return var_type == proto::VarType::SELECTED_ROWS;
   }
 
   bool IsDenseTensorVectorInput(const std::string& name) const override {
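
The change above splits the argument-mapping query in two: the singular IsDenseTensorInput now inspects a single variable via GetInputVarType, while the new plural IsDenseTensorInputs keeps the all-of scan for multi-tensor inputs. A minimal standalone sketch of the pattern and its cost difference (the VarType and Context names below are invented for illustration, not Paddle's real classes):

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

enum class VarType { kDenseTensor, kSelectedRows };

struct Context {
  std::unordered_map<std::string, std::vector<VarType>> slots;

  // Singular query: look at exactly one variable; no vector of all input
  // types is materialized. This is the fast path the commit adds.
  bool IsDenseTensorInput(const std::string& name) const {
    return slots.at(name).front() == VarType::kDenseTensor;
  }

  // Plural query: every variable bound to the slot must be dense, as for
  // the multi-tensor "X" input of the fluid sum (add_n) op.
  bool IsDenseTensorInputs(const std::string& name) const {
    const auto& vars = slots.at(name);
    return std::all_of(vars.begin(), vars.end(),
                       [](VarType t) { return t == VarType::kDenseTensor; });
  }
};

int main() {
  Context ctx{{{"X", {VarType::kDenseTensor, VarType::kDenseTensor}}}};
  return ctx.IsDenseTensorInput("X") && ctx.IsDenseTensorInputs("X") ? 0 : 1;
}
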
5 changes: 5 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -365,6 +365,11 @@ std::vector<DDim> InterpretercoreInferShapeContext::GetInputsDim(
   return GetDims(vars);
 }
 
+proto::VarType::Type InterpretercoreInferShapeContext::GetInputVarType(
+    const std::string& name) const {
+  return GetVarType(InputVars(name).at(0));
+}
+
 std::vector<proto::VarType::Type>
 InterpretercoreInferShapeContext::GetInputsVarType(
     const std::string& name) const {
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -100,6 +100,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
 
   std::vector<DDim> GetInputsDim(const std::string& name) const override;
 
+  proto::VarType::Type GetInputVarType(const std::string& name) const override;
+
   std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string& name) const override;
 
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
@@ -245,6 +245,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
 
   bool IsRunMKLDNNKernel() const override;
 
+  proto::VarType::Type GetInputVarType(const std::string &name) const override {
+    return GetVarType(Inputs(name).at(0));
+  }
+
   std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const override {
     return GetVarTypes(Inputs(name));
4 changes: 4 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -981,6 +981,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return GetDims(vars);
   }
 
+  proto::VarType::Type GetInputVarType(const std::string& name) const override {
+    return GetVarType(InputVars(name).at(0));
+  }
+
   std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string& name) const override {
     return GetVarTypes(InputVars(name));
11 changes: 7 additions & 4 deletions paddle/fluid/framework/operator.h
@@ -479,17 +479,20 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
   }
 
   bool IsDenseTensorInput(const std::string& name) const override {
+    const auto* var = ctx_.InputVar(name);
+    return var->IsType<phi::DenseTensor>();
+  }
+
+  bool IsDenseTensorInputs(const std::string& name) const override {
     auto vars = ctx_.MultiInputVar(name);
     return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
       return var->IsType<phi::DenseTensor>();
     });
   }
 
   bool IsSelectedRowsInput(const std::string& name) const override {
-    auto vars = ctx_.MultiInputVar(name);
-    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
-      return var->IsType<phi::SelectedRows>();
-    });
+    const auto* var = ctx_.InputVar(name);
+    return var->IsType<phi::SelectedRows>();
   }
 
   bool IsDenseTensorVectorInput(const std::string& name) const override {
2 changes: 2 additions & 0 deletions paddle/fluid/framework/shape_inference.h
@@ -65,6 +65,8 @@ class InferShapeContext {
   virtual bool HasOutput(const std::string &name) const = 0;
   virtual bool HasAttr(const std::string &name) const = 0;
 
+  virtual proto::VarType::Type GetInputVarType(
+      const std::string &name) const = 0;
   virtual std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const = 0;
   virtual std::vector<proto::VarType::Type> GetOutputsVarType(
9 changes: 9 additions & 0 deletions paddle/fluid/imperative/infer_shape_context.h
@@ -300,6 +300,15 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
     return vec_res;
   }
 
+  framework::proto::VarType::Type GetInputVarType(
+      const std::string& name) const override {
+    auto it = var_map_in_->find(name);
+    PADDLE_ENFORCE_NE(
+        it, var_map_in_->end(),
+        platform::errors::NotFound("can not find [%s] in input", name));
+    return framework::ToVarType(it->second[0]->Var().Type());
+  }
+
   std::vector<framework::proto::VarType::Type> GetInputsVarType(
       const std::string& name) const override {
     std::vector<framework::proto::VarType::Type> vec_res;
6 changes: 6 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -89,6 +89,12 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
         BOOST_GET_CONST(int, ctx->GetAttr("out_dtype")));
     if (data_type >= 0) {
       ctx->SetOutputDataType("Out", data_type);
+    } else {
+      auto x_type = ctx->GetInputDataType("X");
+      if (x_type == framework::proto::VarType::BOOL ||
+          x_type == framework::proto::VarType::INT32) {
+        ctx->SetOutputDataType("Out", framework::proto::VarType::INT64);
+      }
     }
   }
 };
6 changes: 1 addition & 5 deletions paddle/fluid/pybind/eager_utils.cc
@@ -1204,11 +1204,7 @@ paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
                                                   const std::string& op_type,
                                                   ssize_t arg_pos) {
   if (obj == Py_None) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "%s(): argument (position %d) must be "
-        "data_type, but got %s",
-        op_type, arg_pos + 1,
-        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
+    return paddle::experimental::DataType::UNDEFINED;
   }
 
   framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos);
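
Returning UNDEFINED instead of throwing turns dtype=None into a sentinel meaning "let dtype inference decide", which the final_state_sum fast path in python/paddle/tensor/math.py below relies on. A minimal sketch of the sentinel pattern, with a nullptr standing in for Py_None and a simplified DataType enum:

enum class DataType { UNDEFINED, BOOL, INT32, INT64, FLOAT32 };

// Sketch: instead of rejecting a missing dtype argument, hand back a
// sentinel so kernel-side inference can pick the real output type.
DataType CastDtypeArg(const DataType* arg /* nullptr plays Py_None */) {
  if (arg == nullptr) {
    return DataType::UNDEFINED;  // defer the decision downstream
  }
  return *arg;
}

int main() { return CastDtypeArg(nullptr) == DataType::UNDEFINED ? 0 : 1; }
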
6 changes: 6 additions & 0 deletions paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc
@@ -63,6 +63,12 @@ bool ProtoArgumentMappingContext::IsDenseTensorInput(
     const std::string& name) const {
   return true;
 }
+
+bool ProtoArgumentMappingContext::IsDenseTensorInputs(
+    const std::string& name) const {
+  return true;
+}
+
 bool ProtoArgumentMappingContext::IsSelectedRowsInput(
     const std::string& name) const {
   return false;
1 change: 1 addition & 0 deletions paddle/infrt/dialect/phi/pass/proto_arg_map_context.h
@@ -41,6 +41,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
   size_t OutputSize(const std::string& name) const override;
 
   bool IsDenseTensorInput(const std::string& name) const override;
+  bool IsDenseTensorInputs(const std::string& name) const override;
   bool IsSelectedRowsInput(const std::string& name) const override;
   bool IsDenseTensorVectorInput(const std::string& name) const override;
 
21 changes: 21 additions & 0 deletions paddle/phi/core/compat/arg_map_context.h
@@ -58,13 +58,33 @@ struct KernelSignature {
 
   // TODO(chenweihang): add assign constructor to solve windows compile
   // problem, remove it later
+  KernelSignature(const KernelSignature& other)
+      : name(other.name),
+        input_names(other.input_names),
+        attr_names(other.attr_names),
+        output_names(other.output_names) {}
+
+  KernelSignature(KernelSignature&& other) noexcept
+      : name(other.name),
+        input_names(std::move(other.input_names)),
+        attr_names(std::move(other.attr_names)),
+        output_names(std::move(other.output_names)) {}
+
   KernelSignature& operator=(const KernelSignature& other) {
     name = other.name;
     input_names = other.input_names;
    attr_names = other.attr_names;
    output_names = other.output_names;
    return *this;
  }
+
+  KernelSignature& operator=(KernelSignature&& other) noexcept {
+    name = other.name;
+    input_names.swap(other.input_names);
+    attr_names.swap(other.attr_names);
+    output_names.swap(other.output_names);
+    return *this;
+  }
 };
 
 std::ostream& operator<<(std::ostream& os, KernelSignature signature);
@@ -86,6 +106,7 @@ class ArgumentMappingContext {
   virtual size_t OutputSize(const std::string& name) const = 0;
 
   virtual bool IsDenseTensorInput(const std::string& name) const = 0;
+  virtual bool IsDenseTensorInputs(const std::string& name) const = 0;
   virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
   // For compatibility with LoDTensorArray
   virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
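
Dygraph resolves a KernelSignature on every op call and passes it around by value, so the copy constructor's three container copies sat on a hot path; the move overloads let a returned signature hand over its buffers instead. A minimal sketch of the effect, using std::vector in place of the real struct's container members:

#include <utility>
#include <vector>

struct Signature {
  const char* name = nullptr;
  std::vector<const char*> input_names;
  std::vector<const char*> attr_names;
  std::vector<const char*> output_names;

  Signature() = default;
  Signature(const Signature&) = default;      // deep-copies three containers
  Signature(Signature&&) noexcept = default;  // steals three heap buffers
  Signature& operator=(const Signature&) = default;
  Signature& operator=(Signature&&) noexcept = default;
};

// A per-op-call lookup returning by value: with a noexcept move constructor
// the buffers are transferred; without one, each return would reallocate.
Signature MakeAddNSignature() {
  Signature sig;
  sig.name = "add_n";
  sig.input_names = {"X"};
  sig.output_names = {"Out"};
  return sig;  // moved (or elided), never deep-copied
}

int main() {
  Signature s = MakeAddNSignature();
  Signature t = std::move(s);  // O(1) buffer handover, not element-wise copy
  return t.input_names.size() == 1 ? 0 : 1;
}

The noexcept in the second cherry-picked change matters beyond documentation: standard containers such as std::vector only move elements during reallocation when the move constructor cannot throw, and otherwise fall back to copying.
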
3 changes: 1 addition & 2 deletions paddle/phi/infermeta/unary.cc
@@ -2254,8 +2254,7 @@ void SumRawInferMeta(const MetaTensor& x,
   if (dtype != DataType::UNDEFINED) {
     out_dtype = dtype;
   } else {
-    if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32 ||
-        x.dtype() == DataType::INT64) {
+    if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32) {
       out_dtype = DataType::INT64;
     } else {
       out_dtype = x.dtype();
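
Dropping INT64 from the list is behavior-preserving: an int64 input now falls through to the else branch and keeps its own dtype, which is exactly the int64 the removed condition produced. The resulting rule, as a standalone sketch with a simplified DataType enum:

enum class DataType { UNDEFINED, BOOL, INT32, INT64, FLOAT32 };

// Output dtype of sum when the caller passed no explicit dtype.
DataType SumOutDtype(DataType requested, DataType x) {
  if (requested != DataType::UNDEFINED) {
    return requested;
  }
  if (x == DataType::BOOL || x == DataType::INT32) {
    return DataType::INT64;  // small integer types promote to int64
  }
  return x;  // int64 already maps to itself, hence the redundant branch
}

int main() {
  bool ok =
      SumOutDtype(DataType::UNDEFINED, DataType::BOOL) == DataType::INT64 &&
      SumOutDtype(DataType::UNDEFINED, DataType::INT64) == DataType::INT64 &&
      SumOutDtype(DataType::UNDEFINED, DataType::FLOAT32) == DataType::FLOAT32;
  return ok ? 0 : 1;
}
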
3 changes: 3 additions & 0 deletions paddle/phi/kernels/cpu/reduce_sum_kernel.cc
@@ -29,6 +29,9 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
+    out_dtype = out->dtype();
+  }
  phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
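
This guard reconnects the kernel with the promotion decided earlier: ReduceSumVarTypeInference (reduce_sum_op.cc above) stamps the output variable of a bool/int32 sum as int64 at var-type-inference time, while the kernel itself may still receive out_dtype as UNDEFINED, so it recovers the decision from the output tensor. A sketch of the recovery step, with DenseTensor reduced to a dtype holder:

enum class DataType { UNDEFINED, BOOL, INT32, INT64 };

struct DenseTensor {
  DataType dtype_ = DataType::UNDEFINED;
  DataType dtype() const { return dtype_; }
};

// If no explicit out_dtype arrived but the output tensor was already stamped
// with a dtype different from the input's, trust the stamped dtype.
DataType ResolveOutDtype(DataType out_dtype, const DenseTensor& x,
                         const DenseTensor& out) {
  if (out_dtype == DataType::UNDEFINED && out.dtype() != x.dtype()) {
    out_dtype = out.dtype();
  }
  return out_dtype;
}

int main() {
  DenseTensor x{DataType::BOOL};
  DenseTensor out{DataType::INT64};  // stamped during var-type inference
  return ResolveOutDtype(DataType::UNDEFINED, x, out) == DataType::INT64 ? 0
                                                                         : 1;
}

The GPU kernel below applies the identical guard.
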
3 changes: 3 additions & 0 deletions paddle/phi/kernels/gpu/reduce_sum_kernel.cu
@@ -27,6 +27,9 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
+    out_dtype = out->dtype();
+  }
  phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/sum_sig.cc
@@ -18,7 +18,7 @@
 namespace phi {
 
 KernelSignature SumOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  if (ctx.IsDenseTensorInput("X")) {
+  if (ctx.IsDenseTensorInputs("X")) {
     return KernelSignature("add_n", {"X"}, {}, {"Out"});
   }
   return KernelSignature("unregistered", {}, {}, {});
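
The fluid sum op adds a whole list of tensors bound to "X", so mapping it to phi's add_n kernel must check that every one of them is dense; the plural query expresses that, while the old singular call would now inspect only the first variable. A sketch of the dispatch with the mapping context stubbed down to the single query it needs (not the real phi::ArgumentMappingContext interface):

#include <string>

struct StubMappingContext {
  bool x_all_dense = true;
  bool IsDenseTensorInputs(const std::string& name) const {
    return name == "X" && x_all_dense;
  }
};

struct StubSignature {
  std::string kernel;
};

// Mirrors SumOpArgumentMapping: gate the add_n kernel on *all* tensors in
// "X" being dense tensors.
StubSignature MapSumOp(const StubMappingContext& ctx) {
  if (ctx.IsDenseTensorInputs("X")) {
    return {"add_n"};
  }
  return {"unregistered"};
}

int main() {
  StubMappingContext ctx;
  return MapSumOp(ctx).kernel == "add_n" ? 0 : 1;
}
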
4 changes: 4 additions & 0 deletions paddle/phi/tests/ops/test_op_signature.h
@@ -68,6 +68,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
     return dense_tensor_inputs.count(name) > 0;
   }
 
+  bool IsDenseTensorInputs(const std::string& name) const override {
+    return dense_tensor_inputs.count(name) > 0;
+  }
+
   bool IsSelectedRowsInput(const std::string& name) const override {
     return selected_rows_inputs.count(name) > 0;
   }
24 changes: 8 additions & 16 deletions python/paddle/tensor/math.py
@@ -899,33 +899,25 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     else:
         reduce_all_flag = False
 
-    def get_dtype(x, dtype):
-        if dtype is not None:
-            return (True, dtype)
-        src_type = convert_dtype(x.dtype)
-        if src_type in ['bool','int32', 'int64']:
-            return (True, 'int64')
-        return (False, src_type)
-
-    dtype_flag, dtype = get_dtype(x, dtype)
+    dtype_flag = False
+    if dtype is not None:
+        dtype_flag = True
+        dtype = convert_np_dtype_to_dtype_(dtype)
 
     if in_dygraph_mode():
         if reduce_all_flag:
             axis = range(len(x.shape))
         else:
             axis = axis if axis != None and axis != [] else [0]
 
-        out_dtype = convert_np_dtype_to_dtype_(dtype)
-        out = _C_ops.final_state_sum(x, axis, out_dtype, keepdim)
-        return out
+        return _C_ops.final_state_sum(x, axis, dtype, keepdim)
 
     if _in_legacy_dygraph():
         axis = axis if axis != None and axis != [] else [0]
         if dtype_flag:
             return _C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim,
                                      'reduce_all', reduce_all_flag, 'in_dtype',
-                                     x.dtype, 'out_dtype',
-                                     convert_np_dtype_to_dtype_(dtype))
+                                     x.dtype, 'out_dtype', dtype)
         else:
             return _C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim,
                                      'reduce_all', reduce_all_flag)
@@ -939,7 +931,7 @@ def get_dtype(x, dtype):
     if dtype_flag:
         attrs.update({
             'in_dtype': x.dtype,
-            'out_dtype': convert_np_dtype_to_dtype_(dtype)
+            'out_dtype': dtype
         })
 
     check_variable_and_dtype(
@@ -953,7 +945,7 @@ def get_dtype(x, dtype):
     helper = LayerHelper('sum', **locals())
     if dtype_flag:
         out = helper.create_variable_for_type_inference(
-            dtype=convert_np_dtype_to_dtype_(dtype))
+            dtype=dtype)
     else:
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
