Optimize the pybind in dygraph #42343

Merged · 1 commit · Apr 28, 2022
1 change: 0 additions & 1 deletion paddle/fluid/framework/data_transform.cc
@@ -125,7 +125,6 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
 #ifdef PADDLE_WITH_MKLDNN
     tran_lod_tensor->set_mem_desc(in_lod_tensor.mem_desc());
 #endif
-    tran_lod_tensor->set_layout(in_lod_tensor.layout());
     tran_lod_tensor->ShareDataWith(tensor);
   } else if (in_var.IsType<phi::SelectedRows>()) {
     auto &in_selected_rows = in_var.Get<phi::SelectedRows>();
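A note on this deletion (a guess at the rationale, which the PR does not state): `ShareDataWith` copies the source tensor's metadata wholesale via assignment, so a `set_layout()` issued immediately before it is discarded anyway. A standalone sketch of that ordering with simplified stand-in types, assuming this metadata-overwrite behavior:

```cpp
// Simplified stand-ins, not Paddle's real types. If ShareDataWith assigns
// the whole object (metadata included), the preceding set_layout() is dead.
struct Meta { int layout = 0; };

struct Tensor {
  Meta meta;
  void set_layout(int l) { meta.layout = l; }
  void ShareDataWith(const Tensor& src) { *this = src; }  // overwrites meta
};

int main() {
  Tensor in, staging, out;
  in.meta.layout = 1;
  staging.meta.layout = 2;
  out.set_layout(in.meta.layout);  // effect discarded immediately ...
  out.ShareDataWith(staging);      // ... this assignment replaces the meta
  return out.meta.layout == 2 ? 0 : 1;  // layout comes from `staging`
}
```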
22 changes: 13 additions & 9 deletions paddle/fluid/imperative/tracer.cc
@@ -220,30 +220,34 @@ void Tracer::TraceOpImpl(const std::string& type,
       attr_checker == nullptr ? empty_attrs_map
                               : attr_checker->GetDefaultAttrMap();
 
-  NameVarMap<VarType> new_ins = ins;
+  std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
   if (amp_level_ == AmpLevel::O1) {
     if (amp_dtype_ == phi::DataType::FLOAT16) {
       const auto& tracer = imperative::GetCurrentTracer();
-      new_ins =
-          imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
       VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
-      new_ins = AutoCastInputs<VarType>(type, new_ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          AutoCastInputs<VarType>(type, imperative::AutoTuneLayout<VarType>(
+              type, ins, outs, &attrs, tracer)));
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
-      new_ins = AutoCastBF16Inputs<VarType>(type, ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          AutoCastBF16Inputs<VarType>(type, ins));
     }
   } else if (amp_level_ == AmpLevel::O2) {
     if (amp_dtype_ == phi::DataType::FLOAT16) {
       const auto& tracer = imperative::GetCurrentTracer();
-      new_ins =
-          imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
       VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
-      new_ins = CastPureFp16Inputs<VarType>(type, new_ins);
+      ins_amp =
+          std::make_unique<NameVarMap<VarType>>(CastPureFp16Inputs<VarType>(
+              type, imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs,
+                                                        tracer)));
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
-      new_ins = CastPureBf16Inputs<VarType>(type, ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          CastPureBf16Inputs<VarType>(type, ins));
     }
   }
+  const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp;
 
   try {
     if (platform::is_gpu_place(place)) {
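The tracer change replaces the unconditional copy `NameVarMap<VarType> new_ins = ins;` with a `unique_ptr` that is populated only when AMP actually casts the inputs; on the common non-AMP path, `new_ins` becomes a const reference aliasing `ins` and no map copy is made. A minimal sketch of the pattern with simplified stand-in types (not Paddle's):

```cpp
// Copy-avoidance pattern: materialize a transformed input map only when
// casting happens; otherwise alias the original by const reference.
#include <map>
#include <memory>
#include <string>
#include <vector>

using NameVarMap = std::map<std::string, std::vector<int>>;  // stand-in

NameVarMap CastInputs(const NameVarMap& ins) {
  return ins;  // placeholder for AutoCastInputs / CastPureFp16Inputs etc.
}

void TraceOp(const NameVarMap& ins, bool amp_enabled) {
  std::unique_ptr<NameVarMap> ins_amp;  // stays null on the common path
  if (amp_enabled) {
    ins_amp = std::make_unique<NameVarMap>(CastInputs(ins));
  }
  // No copy of `ins` when AMP is off; `new_ins` is just an alias.
  const NameVarMap& new_ins = ins_amp == nullptr ? ins : *ins_amp;
  (void)new_ins;  // ... run the op with new_ins ...
}
```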
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/op_function_common.cc
@@ -282,6 +282,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
   std::vector<int> value;
   if (PyList_Check(obj)) {
     Py_ssize_t len = PyList_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PyList_GetItem(obj, i);
@@ -298,6 +299,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
     }
   } else if (PyTuple_Check(obj)) {
     Py_ssize_t len = PyTuple_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PyTuple_GetItem(obj, i);
@@ -314,6 +316,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
     }
   } else if (PySequence_Check(obj)) {
     Py_ssize_t len = PySequence_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PySequence_GetItem(obj, i);
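The three `value.reserve(len)` additions follow the standard `std::vector` idiom: when the element count is known up front (from `PyList_Size`, `PyTuple_Size`, or `PySequence_Size`), reserving capacity replaces repeated geometric regrowth, and the copies it entails, with a single allocation. A standalone illustration:

```cpp
// With a known element count, reserve() performs one allocation up front,
// so the push_back loop never triggers a reallocation.
#include <cstdio>
#include <vector>

int main() {
  const std::size_t len = 1000;  // analogous to the Py_ssize_t len above
  std::vector<int> value;
  value.reserve(len);            // single allocation
  for (std::size_t i = 0; i < len; ++i) {
    value.push_back(static_cast<int>(i));  // never reallocates below `len`
  }
  std::printf("size=%zu capacity=%zu\n", value.size(), value.capacity());
  return 0;
}
```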
25 changes: 13 additions & 12 deletions paddle/fluid/pybind/op_function_generator.cc
@@ -81,13 +81,13 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
 const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
 
 const char* CAST_VAR_TEMPLATE = R"(
-  auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetVarBaseFromArgs(op_type, "%s", args, %d, %s);)";
 
 const char* CAST_VAR_LIST_TEMPLATE = R"(
-  auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetVarBaseListFromArgs(op_type, "%s", args, %d, %s);)";
 
 const char* CAST_SIZE_T_TEMPLATE = R"(
-  auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetUnsignedLongFromArgs(op_type, "%s", args, %d, %s);)";
 
 const char* ARG_TEMPLATE = R"(const %s& %s)";
@@ -126,16 +126,17 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
   PyThreadState *tstate = nullptr;
   try
   {
+    std::string op_type = "%s";
     platform::RecordEvent op_type_record_event("%s pybind_imperative_func");
     %s
     framework::AttributeMap attrs;
-    ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
+    ConstructAttrMapFromPyArgs(op_type, args, %d, PyTuple_GET_SIZE(args) , attrs);
     tstate = PyEval_SaveThread();
     %s
     imperative::NameVarBaseMap outs = %s;
     imperative::NameVarBaseMap ins = %s;
     %s
-    imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
+    imperative::GetCurrentTracer()->TraceOp(op_type, ins, outs, attrs, {%s});
     PyEval_RestoreThread(tstate);
     tstate = nullptr;
     %s
@@ -208,8 +209,8 @@ std::string GenerateOpFunctionsBody(
     const auto in_cast_type =
         input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
     auto dispensable = input.dispensable() ? "true" : "false";
-    ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
-                                            in_name, arg_idx++, dispensable);
+    ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name,
+                                            arg_idx++, dispensable);
 
     if (input.dispensable()) {
       const auto in_template = input.duplicable()
@@ -279,8 +280,8 @@ std::string GenerateOpFunctionsBody(
       const auto in_cast_type =
           output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
       auto dispensable = output.dispensable() ? "true" : "false";
-      ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
-                                              out_name, arg_idx++, dispensable);
+      ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name,
+                                              arg_idx++, dispensable);
     } else if (use_inplace_strategy && inplace_map.count(out_name)) {
       PADDLE_ENFORCE_NE(
           inplace_map[out_name], "",
@@ -329,7 +330,7 @@ std::string GenerateOpFunctionsBody(
 
       auto dispensable = output.dispensable() ? "true" : "false";
       ins_cast_str +=
-          paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
+          paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str,
                                   out_num_str, arg_idx++, dispensable);
     } else {
       outs_initializer +=
@@ -375,11 +376,11 @@
 
   // generate op funtcion body
   auto op_function_str = paddle::string::Sprintf(
-      OP_FUNCTION_TEMPLATE, func_name, op_type, ins_cast_str, op_type,
+      OP_FUNCTION_TEMPLATE, func_name, op_type, op_type, ins_cast_str,
       input_args_num, inplace_strategy_str, outs_initializer, ins_initializer,
       ins_initializer_with_null + outs_initializer_with_null +
           view_strategy_str,
-      op_type, inplace_mapping_str, return_str);
+      inplace_mapping_str, return_str);
 
   return op_function_str;
 }
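The generator change binds the op name once per generated function (`std::string op_type = "%s";`) and passes that variable to `GetVarBaseFromArgs`, `ConstructAttrMapFromPyArgs`, and `TraceOp`, instead of re-emitting the name as a string literal at every call site. Since those helpers take the name as `const std::string&` (per the templates above), the literal-per-call version constructs a temporary `std::string` at each call; binding once constructs it once per invocation. A standalone sketch of that cost, with a hypothetical `Helper` standing in for those functions:

```cpp
// `Helper` is a stand-in for functions like GetVarBaseFromArgs that take
// the op name as `const std::string&`.
#include <cstddef>
#include <string>

static std::size_t total = 0;

void Helper(const std::string& op_type) { total += op_type.size(); }

void Before() {        // old generated code: literal at each call site
  Helper("relu");      // temporary std::string built here
  Helper("relu");      // ... and again here
  Helper("relu");      // ... and again
}

void After() {         // new generated code: one binding, reused
  std::string op_type = "relu";  // one construction
  Helper(op_type);
  Helper(op_type);
  Helper(op_type);
}

int main() {
  Before();
  After();
  return 0;
}
```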
6 changes: 3 additions & 3 deletions paddle/phi/core/compat/arg_map_context.h
@@ -80,9 +80,9 @@ struct KernelSignature {
 
   KernelSignature& operator=(KernelSignature&& other) noexcept {
     name = other.name;
-    input_names.swap(other.input_names);
-    attr_names.swap(other.attr_names);
-    output_names.swap(other.output_names);
+    input_names = std::move(other.input_names);
+    attr_names = std::move(other.attr_names);
+    output_names = std::move(other.output_names);
     return *this;
   }
 };
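In `KernelSignature`'s move-assignment operator, `swap` kept the destination's old name lists alive inside `other`, whereas `std::move` releases them immediately and leaves `other` in the conventional moved-from state. A standalone illustration with a simplified struct:

```cpp
// swap() in a move-assignment operator would park the destination's old
// buffers in `other`; std::move frees them right away instead.
#include <cassert>
#include <string>
#include <utility>
#include <vector>

struct Signature {
  std::vector<std::string> input_names;

  Signature& operator=(Signature&& other) noexcept {
    input_names = std::move(other.input_names);  // old buffer released here
    return *this;
  }
};

int main() {
  Signature a, b;
  a.input_names = {"x", "y"};  // discarded by the assignment below
  b.input_names = {"out"};
  a = std::move(b);
  assert(a.input_names.size() == 1);
  // b.input_names is now moved-from (empty in practice) rather than holding
  // a's old {"x", "y"} as it would after swap().
  return 0;
}
```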