Skip to content

Commit

Permalink
optimize the pybind in dygraph (PaddlePaddle#42343)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Apr 28, 2022
1 parent c8b6654 commit fa1d84f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 24 deletions.
22 changes: 13 additions & 9 deletions paddle/fluid/imperative/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,30 +220,34 @@ void Tracer::TraceOpImpl(const std::string& type,
attr_checker == nullptr ? empty_attrs_map
: attr_checker->GetDefaultAttrMap();

NameVarMap<VarType> new_ins = ins;
std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
if (amp_level_ == AmpLevel::O1) {
if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer();
new_ins =
imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastInputs<VarType>(type, new_ins);
ins_amp = std::make_unique<NameVarMap<VarType>>(
AutoCastInputs<VarType>(type, imperative::AutoTuneLayout<VarType>(
type, ins, outs, &attrs, tracer)));
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastBF16Inputs<VarType>(type, ins);
ins_amp = std::make_unique<NameVarMap<VarType>>(
AutoCastBF16Inputs<VarType>(type, ins));
}
} else if (amp_level_ == AmpLevel::O2) {
if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer();
new_ins =
imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureFp16Inputs<VarType>(type, new_ins);
ins_amp =
std::make_unique<NameVarMap<VarType>>(CastPureFp16Inputs<VarType>(
type, imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs,
tracer)));
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureBf16Inputs<VarType>(type, ins);
ins_amp = std::make_unique<NameVarMap<VarType>>(
CastPureBf16Inputs<VarType>(type, ins));
}
}
const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp;

try {
if (platform::is_gpu_place(place)) {
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/op_function_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
std::vector<int> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
value.reserve(len);
PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
Expand All @@ -298,6 +299,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
}
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
value.reserve(len);
PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
Expand All @@ -314,6 +316,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
}
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
value.reserve(len);
PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
Expand Down
25 changes: 13 additions & 12 deletions paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";

const char* CAST_VAR_TEMPLATE = R"(
auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
auto %s = GetVarBaseFromArgs(op_type, "%s", args, %d, %s);)";

const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
auto %s = GetVarBaseListFromArgs(op_type, "%s", args, %d, %s);)";

const char* CAST_SIZE_T_TEMPLATE = R"(
auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
auto %s = GetUnsignedLongFromArgs(op_type, "%s", args, %d, %s);)";

const char* ARG_TEMPLATE = R"(const %s& %s)";

Expand Down Expand Up @@ -126,16 +126,17 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
PyThreadState *tstate = nullptr;
try
{
std::string op_type = "%s";
platform::RecordEvent op_type_record_event("%s pybind_imperative_func");
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
ConstructAttrMapFromPyArgs(op_type, args, %d, PyTuple_GET_SIZE(args) , attrs);
tstate = PyEval_SaveThread();
%s
imperative::NameVarBaseMap outs = %s;
imperative::NameVarBaseMap ins = %s;
%s
imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
imperative::GetCurrentTracer()->TraceOp(op_type, ins, outs, attrs, {%s});
PyEval_RestoreThread(tstate);
tstate = nullptr;
%s
Expand Down Expand Up @@ -208,8 +209,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
in_name, arg_idx++, dispensable);
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name,
arg_idx++, dispensable);

if (input.dispensable()) {
const auto in_template = input.duplicable()
Expand Down Expand Up @@ -279,8 +280,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
out_name, arg_idx++, dispensable);
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name,
arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE(
inplace_map[out_name], "",
Expand Down Expand Up @@ -329,7 +330,7 @@ std::string GenerateOpFunctionsBody(

auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str,
out_num_str, arg_idx++, dispensable);
} else {
outs_initializer +=
Expand Down Expand Up @@ -375,11 +376,11 @@ std::string GenerateOpFunctionsBody(

// generate op function body
auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, func_name, op_type, ins_cast_str, op_type,
OP_FUNCTION_TEMPLATE, func_name, op_type, op_type, ins_cast_str,
input_args_num, inplace_strategy_str, outs_initializer, ins_initializer,
ins_initializer_with_null + outs_initializer_with_null +
view_strategy_str,
op_type, inplace_mapping_str, return_str);
inplace_mapping_str, return_str);

return op_function_str;
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/core/compat/arg_map_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ struct KernelSignature {

// Move-assignment: transfer all signature data from `other`.
// NOTE: the rendered diff had left both the old swap-based lines and the
// new move-based lines in place; executed together, the swap exchanges
// contents with `other` and the subsequent move-assignment pulls this
// object's *old* vectors back, so `other`'s data was never transferred.
// Keep only the move-based version.
KernelSignature& operator=(KernelSignature&& other) noexcept {
  // NOTE(review): `name` is copied, not moved — presumably cheap (e.g. a
  // const char* or SSO string); confirm against the field declaration.
  name = other.name;
  // Move each name list; `other` is left in a valid-but-unspecified state,
  // as expected of a moved-from object.
  input_names = std::move(other.input_names);
  attr_names = std::move(other.attr_names);
  output_names = std::move(other.output_names);
  return *this;
}
};
Expand Down

0 comments on commit fa1d84f

Please sign in to comment.