From 2d639b87b57ee2d08e13dd1e065b742581f5f561 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Sun, 10 Sep 2023 12:58:07 +0000 Subject: [PATCH 1/2] Remove grad apis in c_ops --- .../fluid/pir/dialect/op_generator/api_gen.py | 23 +++++++------ .../fluid/pir/dialect/op_generator/op_gen.py | 4 +-- .../pir/dialect/op_generator/ops_api_gen.py | 32 ++++++++++--------- .../pir/dialect/op_generator/python_c_gen.py | 18 ++--------- .../pir/dialect/operator/ir/manual_api.cc | 23 +------------ .../pir/dialect/operator/ir/manual_api.h | 8 +---- paddle/phi/api/yaml/op_compat.yaml | 1 + 7 files changed, 38 insertions(+), 71 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py index 7d6787ef7707e..b0013431fcb6b 100644 --- a/paddle/fluid/pir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -17,7 +17,12 @@ import re import yaml -from op_gen import OpCompatParser, OpInfoParser, to_pascal_case +from op_gen import ( + PD_MANUAL_OP_LIST, + OpCompatParser, + OpInfoParser, + to_pascal_case, +) H_FILE_TEMPLATE = """ @@ -81,7 +86,6 @@ OP_RESULT = 'pir::OpResult' VECTOR_TYPE = 'pir::VectorType' -PD_MANUAL_OP_LIST = ['add_n'] def get_op_class_name(op_name): @@ -111,6 +115,11 @@ def _parse_yaml(self, op_yaml_files, op_compat_yaml_file): ) return op_info_items + def _need_skip(self, op_info, op_name): + return ( + op_info.infer_meta_func is None and op_name not in PD_MANUAL_OP_LIST + ) + # ===================================== # Gen declare functions # ===================================== @@ -191,10 +200,7 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path): for op_name in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue declare_str += self._gen_one_declare(op_info, op_name, False) if len(op_info.mutable_attribute_name_list) > 0: @@ -325,10 +331,7 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path): for op_name in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue impl_str += self._gen_one_impl(op_info, op_name, False) if len(op_info.mutable_attribute_name_list) > 0: diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 2c84d0a404131..4ece8c08dfda9 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -173,7 +173,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ 'bool': 'pir::BoolAttribute', } -_NO_NEED_GEN_OPS = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'} +PD_MANUAL_OP_LIST = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'} def to_phi_and_fluid_op_name(op_item): @@ -881,7 +881,7 @@ def OpGenerator( # If op has inplace info, we will generate inplace op and non-inplace op. for op_name in op_info.op_phi_name: - if op_name in _NO_NEED_GEN_OPS: + if op_name in PD_MANUAL_OP_LIST: continue op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." + op_name diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index bde5c7c23a7bc..c8ed91332ff5d 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -15,7 +15,7 @@ import argparse import os -from api_gen import NAMESPACE_TEMPLATE, PD_MANUAL_OP_LIST, CodeGen +from api_gen import NAMESPACE_TEMPLATE, CodeGen CPP_FILE_TEMPLATE = """ #include @@ -55,7 +55,7 @@ }} }}""" -NO_DY_FUNCTION_IMPL_TEMPLATE = """ +STATIC_ONLY_FUNCTION_IMPL_TEMPLATE = """ static PyObject *{name}(PyObject *self, PyObject *args, PyObject *kwargs) {{ VLOG(6) << "Call static_api_{name}"; return static_api_{name}(self, args, kwargs); @@ -64,8 +64,9 @@ OPS_API_TEMPLATE = """ {{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},""" -SPECIAL_STATIC_ONLY_APIS = [ - 'fetch', +NEED_GEN_STATIC_ONLY_APIS = ['fetch'] + +NO_NEED_GEN_STATIC_ONLY_APIS = [ 'set_value_with_tensor', 'set_value_with_tensor_', 'fused_bn_add_activation_', @@ -93,14 +94,18 @@ class OpsAPIGen(CodeGen): def __init__(self) -> None: super().__init__() + def _need_skip(self, op_info, op_name): + return ( + super()._need_skip(op_info, op_name) + or op_name.endswith('_grad') + or op_name.endswith('_grad_') + or op_name.endswith('xpu') + or op_name in NO_NEED_GEN_STATIC_ONLY_APIS + ) + def _gen_one_function_impl(self, name): - if ( - name.endswith('_grad') - or name.endswith('_grad_') - or name.endswith('xpu') - or name in SPECIAL_STATIC_ONLY_APIS - ): - return NO_DY_FUNCTION_IMPL_TEMPLATE.format(name=name) + if name in NEED_GEN_STATIC_ONLY_APIS: + return STATIC_ONLY_FUNCTION_IMPL_TEMPLATE.format(name=name) else: return FUNCTION_IMPL_TEMPLATE.format(name=name) @@ -117,10 +122,7 @@ def gen_cpp_file( ops_api_str = '' for op_info in op_info_items: for op_name in op_info.op_phi_name: - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue function_impl_str += self._gen_one_function_impl(op_name) ops_api_str += self._gen_one_ops_api(op_name) diff --git a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py index 10f20da3ffe2a..805a7d3750d83 100644 --- a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py @@ -15,13 +15,7 @@ import argparse import re -from api_gen import ( - NAMESPACE_TEMPLATE, - OP_RESULT, - PD_MANUAL_OP_LIST, - VECTOR_TYPE, - CodeGen, -) +from api_gen import NAMESPACE_TEMPLATE, OP_RESULT, VECTOR_TYPE, CodeGen H_FILE_TEMPLATE = """ @@ -195,10 +189,7 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path): for op_name in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue declare_str += self._gen_one_declare(op_name) @@ -332,10 +323,7 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path): for op_name in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue impl_str += self._gen_one_impl(op_info, op_name) body = impl_str diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index 5c3e107686dfd..a6b97e1ffaa0c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -18,26 +18,5 @@ #include "paddle/pir/core/builtin_op.h" namespace paddle { -namespace dialect { -pir::OpResult split_grad(std::vector out_grads, - pir::OpResult axis) { - auto combine_op = - APIBuilder::Instance().GetBuilder()->Build(out_grads); - paddle::dialect::SplitGradOp split_grad_op = - APIBuilder::Instance().GetBuilder()->Build( - combine_op.out(), axis); - - return split_grad_op.x_grad(); -} - -pir::OpResult split_grad(std::vector out_grads, int axis) { - auto combine_op = - APIBuilder::Instance().GetBuilder()->Build(out_grads); - paddle::dialect::SplitGradOp split_grad_op = - APIBuilder::Instance().GetBuilder()->Build( - combine_op.out(), axis); - - return split_grad_op.x_grad(); -} -} // namespace dialect +namespace dialect {} // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h index b98746aa88454..de7c123ee1af5 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -21,11 +21,5 @@ #include "paddle/pir/core/value.h" namespace paddle { -namespace dialect { - -pir::OpResult split_grad(std::vector out_grads, - pir::OpResult axis); - -pir::OpResult split_grad(std::vector out_grads, int axis); -} // namespace dialect +namespace dialect {} // namespace dialect } // namespace paddle diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 495ba53cd7613..8d5e6bf92799b 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2638,6 +2638,7 @@ out : Out - op : split + backward : split_grad inputs: x : X outputs: From 927aa359b71ad3473e8b589e77930e4c3c65ac10 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Mon, 11 Sep 2023 06:49:03 +0000 Subject: [PATCH 2/2] Fix endswith --- paddle/fluid/pir/dialect/op_generator/ops_api_gen.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index c8ed91332ff5d..2d5bb98cbc093 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -97,9 +97,7 @@ def __init__(self) -> None: def _need_skip(self, op_info, op_name): return ( super()._need_skip(op_info, op_name) - or op_name.endswith('_grad') - or op_name.endswith('_grad_') - or op_name.endswith('xpu') + or op_name.endswith(('_grad', '_grad_', 'xpu')) or op_name in NO_NEED_GEN_STATIC_ONLY_APIS )