Skip to content

Commit

Permalink
[PHI CAPI] Add support for registering a new operator, PART2 (PaddleP…
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored and wz1qqx committed Jul 31, 2023
1 parent cabb36b commit 0daf1a2
Show file tree
Hide file tree
Showing 9 changed files with 882 additions and 1 deletion.
113 changes: 113 additions & 0 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ limitations under the License. */
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_kernel_registry.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"
#endif

#include "paddle/phi/api/include/operants_manager.h"
Expand Down Expand Up @@ -1226,3 +1230,112 @@ LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {

} // namespace framework
} // namespace paddle

#ifdef PADDLE_WITH_CUSTOM_DEVICE
void PD_RegisterOperator(const char* kernel_name_cstr,
size_t in_nargs,
PD_KernelArgumentType* in_args_type,
size_t attr_nargs,
PD_KernelArgumentType* attr_args_type,
size_t out_nargs,
PD_KernelArgumentType* out_args_type,
void (*infer_shape_fn)(PD_InferMetaContext*)) {
std::string kernel_name(kernel_name_cstr);
if (infer_shape_fn &&
!paddle::framework::OpInfoMap::Instance().Has(kernel_name)) {
VLOG(8) << "Registering a new operator: " << kernel_name;

std::vector<std::string> op_inputs, op_outputs, op_attrs;

for (size_t i = 0; i < in_nargs; ++i) {
if (in_args_type[i] == PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) {
op_inputs.push_back("Input_" + std::to_string(i));
} else if (in_args_type[i] ==
PD_KernelArgumentType::PD_ARG_TYPE_LIST_TENSOR) {
op_inputs.push_back("Input_" + std::to_string(i) +
paddle::kTensorVectorSuffix);
} else if (in_args_type[i] ==
PD_KernelArgumentType::PD_ARG_TYPE_OPTIONAL_TENSOR) {
op_inputs.push_back("Input_" + std::to_string(i) +
paddle::kOptionalSuffix);
} else {
op_inputs.push_back("Input_unknown");
}
}
for (size_t i = 0; i < out_nargs; ++i) {
if (out_args_type[i] == PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) {
op_outputs.push_back("Output_" + std::to_string(i));
} else if (out_args_type[i] ==
PD_KernelArgumentType::PD_ARG_TYPE_LIST_TENSOR) {
op_outputs.push_back("Output_" + std::to_string(i) +
paddle::kTensorVectorSuffix);
} else {
op_outputs.push_back("Output_unknown");
}
}
for (size_t i = 0; i < attr_nargs; ++i) {
auto attr_type = attr_args_type[i];
if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_BOOL) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":bool");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_INT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":int");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_FLOAT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":float");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_INT64) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":int64_t");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_STRING) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":std::string");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_INT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":std::vector<int>");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_FLOAT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":std::vector<float>");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_INT64) {
op_attrs.push_back("Attr_" + std::to_string(i) +
":std::vector<int64_t>");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_STRING) {
op_attrs.push_back("Attr_" + std::to_string(i) +
":std::vector<std::string>");
} else {
op_attrs.push_back("Attr_unknown");
}
}

paddle::framework::OpInfo info;
// Op
info.creator_ = [](const std::string& op_name,
const paddle::framework::VariableNameMap& inputs,
const paddle::framework::VariableNameMap& outputs,
const paddle::framework::AttributeMap& attrs) {
return new paddle::framework::OperatorWithKernel(
op_name, inputs, outputs, attrs);
};

// OpMaker
info.proto_ = new paddle::framework::proto::OpProto;
info.proto_->set_type(kernel_name);

info.checker_ = new paddle::framework::OpAttrChecker();

paddle::framework::CustomOpMaker custom_maker(
op_inputs, op_outputs, op_attrs);
custom_maker(info.proto_, info.checker_);
PADDLE_ENFORCE_EQ(
info.proto_->IsInitialized(),
true,
phi::errors::PreconditionNotMet(
"Fail to initialize %s's OpProto, because %s is not initialized.",
kernel_name,
info.proto_->InitializationErrorString()));

info.infer_shape_ = [infer_shape_fn, kernel_name](
paddle::framework::InferShapeContext* ctx) {
auto infer_meta_context =
paddle::framework::BuildInferMetaContext(ctx, kernel_name);
infer_shape_fn(
reinterpret_cast<PD_InferMetaContext*>(&infer_meta_context));
};

paddle::framework::OpInfoMap::Instance().Insert(kernel_name, info);
}
}
#endif
2 changes: 2 additions & 0 deletions paddle/phi/capi/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

#include "paddle/phi/capi/include/c_data_type.h"
#include "paddle/phi/capi/include/c_device_context.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_int_array.h"
#include "paddle/phi/capi/include/c_kernel_context.h"
#include "paddle/phi/capi/include/c_kernel_factory.h"
#include "paddle/phi/capi/include/c_kernel_registry.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"
#include "paddle/phi/capi/include/c_place.h"
#include "paddle/phi/capi/include/c_scalar.h"
#include "paddle/phi/capi/include/c_tensor.h"
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/capi/capi.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ PD_DECLARE_CAPI(int_array);
PD_DECLARE_CAPI(kernel_context);
PD_DECLARE_CAPI(kernel_factory);
PD_DECLARE_CAPI(kernel_registry);
PD_DECLARE_CAPI(infer_meta_context);
PD_DECLARE_CAPI(meta_tensor);
PD_DECLARE_CAPI(place);
PD_DECLARE_CAPI(scalar);
PD_DECLARE_CAPI(tensor);
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/capi/include/c_kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <vector>

#include "paddle/phi/capi/include/c_data_type.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_kernel_context.h"
#include "paddle/phi/capi/include/c_kernel_factory.h"

Expand Down Expand Up @@ -71,6 +72,15 @@ void PD_RegisterPhiKernel(const char *kernel_name_cstr,
void (*fn)(PD_KernelContext *),
void *variadic_kernel_fn);

void PD_RegisterOperator(const char *kernel_name_cstr,
size_t in_nargs,
PD_KernelArgumentType *in_args_type,
size_t attr_nargs,
PD_KernelArgumentType *attr_args_type,
size_t out_nargs,
PD_KernelArgumentType *out_args_type,
void (*infer_shape_fn)(PD_InferMetaContext *));

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
Loading

0 comments on commit 0daf1a2

Please sign in to comment.