From 975c09c52b0746b5b2c90d42f9f2867892f31254 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Thu, 21 Sep 2023 18:35:15 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90pir=E3=80=91Modify=20comment=20of=20pr?= =?UTF-8?q?57478=20and=20pr56873=20(#57520)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tmp * reply comment * code style --- .../fluid/pir/dialect/op_generator/api_gen.py | 2 +- .../pir/dialect/op_generator/python_c_gen.py | 2 +- .../pir/dialect/operator/ir/manual_api.cc | 23 ++++++++++--------- .../pir/dialect/operator/ir/manual_api.h | 21 +++++++++-------- .../pir/dialect/operator/ir/manual_op_vjp.cc | 4 +++- .../primitive/backend/manual/manual_backend.h | 1 - 6 files changed, 28 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py index d7e74f72b652f3..851f318e9bc47c 100644 --- a/paddle/fluid/pir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -150,7 +150,7 @@ def _gen_api_inputs(self, op_info): assert len(name_list) == len(type_list) ret = [] for name, type in zip(name_list, type_list): - ret.append(f'{self._type_map[type]} {name}') + ret.append(f'const {self._type_map[type]}& {name}') return ', '.join(ret) def _gen_api_attrs( diff --git a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py index 440f656b999641..adb5270e975e60 100644 --- a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py @@ -174,7 +174,7 @@ """ BUILTIN_STACK_OP_TEMPLATE = """ - {name} = paddle::dialect::stack({name}_tmp, 0); + {name} = paddle::dialect::stack({name}_tmp, /*axis*/0); """ TYPE_TO_FUNC_MAP = { "bool": "CastPyArg2Boolean", diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index 24e7a94b666503..eb5acbf2388ea8 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -28,8 +28,8 @@ pir::OpResult builtin_combine(const std::vector& x) { return combine_op.out(); } -std::vector add_n_grad(std::vector inputs, - pir::Value out_grad) { +std::vector add_n_grad(const std::vector& inputs, + const pir::Value& out_grad) { std::vector inputs_grad; for (size_t i = 0; i < inputs.size(); i++) { paddle::dialect::ScaleOp scale_op = @@ -40,8 +40,8 @@ std::vector add_n_grad(std::vector inputs, return inputs_grad; } -pir::OpResult zeros_like(pir::Value x, - phi::DataType dtype, +pir::OpResult zeros_like(const pir::Value& x, + const phi::DataType dtype, const Place& place) { return paddle::dialect::full_like(x, 0, dtype, place); } @@ -54,7 +54,7 @@ pir::OpResult get_parameter(const std::string& name) { return get_parameter_op.result(0); } -void set_parameter(pir::Value parameter, const std::string& name) { +void set_parameter(const pir::Value& parameter, const std::string& name) { std::unique_ptr param( new pir::Parameter(nullptr, 0, parameter.type())); APIBuilder::Instance().SetParameter(name, std::move(param)); @@ -62,9 +62,9 @@ void set_parameter(pir::Value parameter, const std::string& name) { name); } -pir::OpResult embedding_grad(pir::Value x, - pir::Value weight, - pir::Value out_grad, +pir::OpResult embedding_grad(const pir::Value& x, + const pir::Value& weight, + const pir::Value& out_grad, int64_t padding_idx, bool sparse) { if (weight.type().isa()) { @@ -81,7 +81,8 @@ pir::OpResult embedding_grad(pir::Value x, } } -pir::OpResult split_with_num_grad(std::vector out_grad, int axis) { +pir::OpResult split_with_num_grad(const std::vector& out_grad, + int axis) { auto out_grad_combine_op = APIBuilder::Instance().GetBuilder()->Build(out_grad); paddle::dialect::SplitGradOp split_grad_op = @@ -90,8 +91,8 @@ pir::OpResult split_with_num_grad(std::vector out_grad, int axis) { return split_grad_op.result(0); } -pir::OpResult split_with_num_grad(std::vector out_grad, - pir::Value axis) { +pir::OpResult split_with_num_grad(const std::vector& out_grad, + const pir::Value& axis) { auto out_grad_combine_op = APIBuilder::Instance().GetBuilder()->Build(out_grad); paddle::dialect::SplitGradOp split_grad_op = diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h index c919448f1ddb0e..fe579295ad5a09 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -25,26 +25,27 @@ namespace dialect { pir::OpResult builtin_combine(const std::vector& x); -std::vector add_n_grad(std::vector inputs, - pir::Value out_grad); +std::vector add_n_grad(const std::vector& inputs, + const pir::Value& out_grad); -pir::OpResult zeros_like(pir::Value x, +pir::OpResult zeros_like(const pir::Value& x, phi::DataType dtype = phi::DataType::UNDEFINED, const Place& place = {}); pir::OpResult get_parameter(const std::string& name); -void set_parameter(pir::Value parameter, const std::string& name); +void set_parameter(const pir::Value& parameter, const std::string& name); -pir::OpResult embedding_grad(pir::Value x, - pir::Value weight, - pir::Value out_grad, +pir::OpResult embedding_grad(const pir::Value& x, + const pir::Value& weight, + const pir::Value& out_grad, int64_t padding_idx = -1, bool sparse = false); -pir::OpResult split_with_num_grad(std::vector out_grad, int axis); +pir::OpResult split_with_num_grad(const std::vector& out_grad, + int axis); -pir::OpResult split_with_num_grad(std::vector out_grad, - pir::Value axis); +pir::OpResult split_with_num_grad(const std::vector& out_grad, + const pir::Value& axis); } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc index b6d131e5411fbc..80c13ac89def13 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc @@ -34,7 +34,9 @@ std::vector> AddNOp::Vjp( AddNOp op_obj = op->dyn_cast(); VLOG(6) << "Prepare inputs of add_n_grad"; - + PADDLE_ENFORCE( + op_obj.inputs() != nullptr, + paddle::platform::errors::Fatal("addn op's inputs can't be null")); pir::CombineOp combine_op_obj = op_obj.inputs() .dyn_cast() .owner() diff --git a/paddle/fluid/primitive/backend/manual/manual_backend.h b/paddle/fluid/primitive/backend/manual/manual_backend.h index 16c1facbd5354c..3c9340164ac012 100644 --- a/paddle/fluid/primitive/backend/manual/manual_backend.h +++ b/paddle/fluid/primitive/backend/manual/manual_backend.h @@ -18,7 +18,6 @@ #include #include "paddle/phi/api/include/tensor.h" -#include "paddle/utils/optional.h" namespace paddle { namespace primitive {