Skip to content

Commit

Permalink
[CustomOp] Support attributes as func input in custom op (PaddlePaddl…
Browse files Browse the repository at this point in the history
…e#31128)

* add simple attr support and test

* add int, float attr support

* support other attribute

* add custom attrs test in cmake

* polish details

* fix test failed

* add backward test

* update test flags
  • Loading branch information
chenwhql committed Feb 26, 2021
1 parent d4ffeb3 commit db67746
Show file tree
Hide file tree
Showing 7 changed files with 458 additions and 43 deletions.
68 changes: 48 additions & 20 deletions paddle/fluid/extension/include/op_meta_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@ inline std::string Grad(const std::string& var_name) {
using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
std::vector<boost::any> attrs);

#define PD_SPECIALIZE_ComputeCallHelper(attr_type) \
template <typename... Tail> \
struct ComputeCallHelper<attr_type, Tail...> { \
template <int in_idx, int attr_idx, typename... PreviousArgs> \
static Return Compute(std::vector<Tensor> inputs, \
std::vector<boost::any> attrs, \
const PreviousArgs&... pargs) { \
try { \
attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
return ComputeCallHelper<Tail...>::template Compute<in_idx, \
attr_idx + 1>( \
inputs, attrs, pargs..., arg); \
} catch (boost::bad_any_cast&) { \
PD_THROW( \
"Attribute cast error in custom operator. Expected " #attr_type \
" value."); \
} \
} \
}

template <typename T>
struct TypeTag {};

Expand Down Expand Up @@ -114,26 +134,20 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};

// TODO(chenweihang): add support for attribute input
// int attribute input (not used now)
template <typename... Tail>
struct ComputeCallHelper<int, Tail...> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static Return Compute(std::vector<Tensor> inputs,
std::vector<boost::any> attrs,
const PreviousArgs&... pargs) {
try {
int arg = boost::any_cast<int>(attrs[attr_idx]);
return ComputeCallHelper<Tail...>::template Compute<in_idx,
attr_idx + 1>(
inputs, attrs, pargs..., arg);
} catch (boost::bad_any_cast&) {
PD_THROW(
"Attribute cast error in custom operator. Expected int value.");
}
}
};

PD_SPECIALIZE_ComputeCallHelper(bool);
PD_SPECIALIZE_ComputeCallHelper(int);
PD_SPECIALIZE_ComputeCallHelper(float);
PD_SPECIALIZE_ComputeCallHelper(int64_t);
PD_SPECIALIZE_ComputeCallHelper(std::string);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
// TODO(chenweihang): support other attribute type if needed.
// Why not support other attribute type here?
// - boost::blank, std::vector<bool> and std::vector<double>
// are not used in op
// - BlockDesc* and std::vector<BlockDesc*> are used in framework
// end: base template
template <typename T>
struct ComputeCallHelper<TypeTag<T>> {
Expand Down Expand Up @@ -245,10 +259,23 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
class PD_DLL_DECL OpMetaInfo {
public:
explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}

// format: {"<name1>", "<name2>", ...}
OpMetaInfo& Inputs(std::vector<std::string>&& inputs);

// format: {"<name1>", "<name2>", ...}
OpMetaInfo& Outputs(std::vector<std::string>&& outputs);

// format: {"<name1>:<type1>", "<name1>:<type1>", ...}
OpMetaInfo& Attrs(std::vector<std::string>&& attrs);

// format: PD_KERNEL(...)
OpMetaInfo& SetKernelFn(KernelFunc&& func);

// format: PD_INFER_SHAPE(...)
OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);

// format: PD_INFER_DTYPE(...)
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);

private:
Expand Down Expand Up @@ -297,6 +324,7 @@ class PD_DLL_DECL OpMetaInfoBuilder {
explicit OpMetaInfoBuilder(std::string&& name);
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/extension/src/op_meta_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
outputs_ = std::forward<std::vector<std::string>>(outputs);
return *this;
}
OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) {
attrs_ = std::forward<std::vector<std::string>>(attrs);
return *this;
}
OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
kernel_fn_ = std::forward<KernelFunc>(func);
return *this;
Expand Down Expand Up @@ -78,6 +82,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs));
return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
return *this;
Expand Down
132 changes: 118 additions & 14 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
}

std::vector<std::string> ParseAttrStr(const std::string& attr) {
auto split_pos = attr.find_first_of(":");
PADDLE_ENFORCE_NE(split_pos, std::string::npos,
platform::errors::InvalidArgument(
"Invalid attribute string format. Attribute string "
"format is `<name>:<type>`."));

std::vector<std::string> rlt;
// 1. name
rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos)));
// 2. type
rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));

VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];

return rlt;
}

} // namespace detail

////////////////// Kernel Define ////////////////////
Expand All @@ -81,7 +99,8 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
static void RunKernelFunc(const framework::ExecutionContext& ctx,
const paddle::KernelFunc& func,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(1) << "Custom Operator: Start run KernelFunc.";
std::vector<paddle::Tensor> custom_ins;
for (auto& in_name : inputs) {
Expand All @@ -98,10 +117,43 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
custom_ins.emplace_back(custom_in);
}

std::vector<boost::any> attrs;
std::vector<boost::any> custom_attrs;
for (auto& attr_str : attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}

VLOG(1) << "Run ComputeFunc.";
auto outs = func(custom_ins, attrs);
auto outs = func(custom_ins, custom_attrs);

VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
for (size_t i = 0; i < outputs.size(); ++i) {
Expand Down Expand Up @@ -164,7 +216,51 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
for (auto& out_name : outputs_) {
AddOutput(out_name, "The output " + out_name + "of Custom Operator.");
}
// TODO(chenweihang): support attrs in later PR
for (auto& attr : attrs_) {
auto attr_name_and_type = detail::ParseAttrStr(attr);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
AddAttr<bool>(attr_name, "custom operator bool attribute.")
.SetDefault(false);
} else if (attr_type_str == "int") {
AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
} else if (attr_type_str == "float") {
AddAttr<float>(attr_name, "custom operator float attribute.")
.SetDefault(1.0f);
} else if (attr_type_str == "int64_t") {
AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
.SetDefault(1);
} else if (attr_type_str == "std::string") {
AddAttr<std::string>(attr_name, "custom operator int attribute.")
.SetDefault("");
} else if (attr_type_str == "std::vector<int>") {
AddAttr<std::vector<int>>(attr_name,
"custom operator std::vector<int> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<float>") {
AddAttr<std::vector<float>>(
attr_name, "custom operator std::vector<float> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<int64_t>") {
AddAttr<std::vector<int64_t>>(
attr_name, "custom operator std::vector<int64_t> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<std::string>") {
AddAttr<std::vector<std::string>>(
attr_name, "custom operator std::vector<std::string> attribute.")
.SetDefault({});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}
AddComment(R"DOC(
Custom Operator.
Expand Down Expand Up @@ -227,7 +323,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
}
// TODO(chenweihang): support attrs in later PR
grad_op->SetAttrMap(this->Attrs());
}

private:
Expand Down Expand Up @@ -287,7 +383,7 @@ class CustomGradOpMaker<imperative::OpBase>
VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
}
// TODO(chenweihang): support attrs in later PR
grad_op->SetAttrMap(this->Attrs());
}

private:
Expand All @@ -303,31 +399,36 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
const proto::VarType::Type type,
const PlaceType& place,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
OpKernelType key(type,
CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place));
VLOG(1) << "Custom Operator: op kernel key: " << key;
OperatorWithKernel::AllOpKernels()[name][key] =
[kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) {
[kernel_func, inputs, outputs,
attrs](const framework::ExecutionContext& ctx) {
VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
RunKernelFunc(ctx, kernel_func, inputs, outputs);
RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
};
}

void RegisterOperatorKernel(const std::string& name,
const paddle::KernelFunc& kernel_func,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(1) << "Custom Operator: op name in kernel: " << name;
// NOTE [ Dummy Op Kernel Key ]
// TODO(chenweihang): Because execute engine need get device context based
// op_kernel_key.place_, so we should register kernel for each
// device. But this is not entirely correct, if user only give a cpu kernel,
// but call api in gpu device, it will cause error.
RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
PlaceType::kCPU, inputs, outputs);
PlaceType::kCPU, inputs, outputs, attrs);
#ifdef PADDLE_WITH_CUDA
RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
PlaceType::kGPU, inputs, outputs);
PlaceType::kGPU, inputs, outputs, attrs);
#endif
}

void RegisterOperatorWithMetaInfo(
Expand All @@ -350,6 +451,8 @@ void RegisterOperatorWithMetaInfo(
<< string::join_strings(op_inputs, ',');
VLOG(1) << "Custom Operator: forward, op outputs: "
<< string::join_strings(op_outputs, ',');
VLOG(1) << "Custom Operator: forward, op attrs: "
<< string::join_strings(op_attrs, ',');

// Op
info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs,
Expand Down Expand Up @@ -426,7 +529,7 @@ void RegisterOperatorWithMetaInfo(
};

// Kernel func
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs);
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);

// If grad op or double grad op exists
std::string cur_op_name = op_name;
Expand All @@ -436,6 +539,7 @@ void RegisterOperatorWithMetaInfo(
auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op);
auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op);
auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);

VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name;
Expand Down Expand Up @@ -489,7 +593,7 @@ void RegisterOperatorWithMetaInfo(

// Kernel func
RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
grad_op_outputs);
grad_op_outputs, grad_op_attrs);

// update current info
OpInfoMap::Instance().Insert(cur_op_name, info);
Expand Down
7 changes: 5 additions & 2 deletions python/paddle/fluid/tests/custom_op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ py_test(test_sysconfig SRCS test_sysconfig.py)

# 'test_dispatch' compile .cc file
py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120)

py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)

py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)

if(NOT LINUX)
return()
Expand Down
Loading

0 comments on commit db67746

Please sign in to comment.