
[PIR] Address review comments from PR 57478 and PR 56873 (PaddlePaddle#57520)
* tmp

* reply comment

* code style
xiaoguoguo626807 authored Sep 21, 2023
1 parent 079dadf commit 975c09c
Showing 6 changed files with 28 additions and 25 deletions.
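
The common thread in the C++ changes below is parameter-passing hygiene: pir::Value and std::vector<pir::Value> arguments that the callees only read were previously taken by value, which copies the vector on every call; the diffs switch them to const references. A minimal sketch of the pattern, using a stand-in Value type rather than the real pir classes:

    #include <vector>

    struct Value {};  // stand-in for pir::Value, not the real class

    // Before: the whole vector is copied into the callee on each call.
    std::vector<Value> add_n_grad_old(std::vector<Value> inputs, Value out_grad);

    // After: read-only parameters are borrowed via const reference; no copy is
    // made, and the signature now states that the callee will not mutate them.
    std::vector<Value> add_n_grad(const std::vector<Value>& inputs,
                                  const Value& out_grad);

(add_n_grad_old is a hypothetical name for illustration; the real before/after declarations appear in the manual_api.h diff below.)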
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -150,7 +150,7 @@ def _gen_api_inputs(self, op_info):
         assert len(name_list) == len(type_list)
         ret = []
         for name, type in zip(name_list, type_list):
-            ret.append(f'{self._type_map[type]} {name}')
+            ret.append(f'const {self._type_map[type]}& {name}')
         return ', '.join(ret)
 
     def _gen_api_attrs(
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/python_c_gen.py
@@ -174,7 +174,7 @@
 """
 
 BUILTIN_STACK_OP_TEMPLATE = """
-    {name} = paddle::dialect::stack({name}_tmp, 0);
+    {name} = paddle::dialect::stack({name}_tmp, /*axis*/0);
 """
 TYPE_TO_FUNC_MAP = {
     "bool": "CastPyArg2Boolean",
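
For reference, once the {name} placeholder is filled in, the template now emits C++ along these lines (the input name x is hypothetical):

    // Hypothetical rendering of BUILTIN_STACK_OP_TEMPLATE for an input named x;
    // the new /*axis*/ marker labels the otherwise-bare literal argument 0.
    x = paddle::dialect::stack(x_tmp, /*axis*/0);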
23 changes: 12 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -28,8 +28,8 @@ pir::OpResult builtin_combine(const std::vector<pir::Value>& x) {
   return combine_op.out();
 }
 
-std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
-                                      pir::Value out_grad) {
+std::vector<pir::OpResult> add_n_grad(const std::vector<pir::Value>& inputs,
+                                      const pir::Value& out_grad) {
   std::vector<pir::OpResult> inputs_grad;
   for (size_t i = 0; i < inputs.size(); i++) {
     paddle::dialect::ScaleOp scale_op =
@@ -40,8 +40,8 @@ std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
   return inputs_grad;
 }
 
-pir::OpResult zeros_like(pir::Value x,
-                         phi::DataType dtype,
+pir::OpResult zeros_like(const pir::Value& x,
+                         const phi::DataType dtype,
                          const Place& place) {
   return paddle::dialect::full_like(x, 0, dtype, place);
 }
@@ -54,17 +54,17 @@ pir::OpResult get_parameter(const std::string& name) {
   return get_parameter_op.result(0);
 }
 
-void set_parameter(pir::Value parameter, const std::string& name) {
+void set_parameter(const pir::Value& parameter, const std::string& name) {
   std::unique_ptr<pir::Parameter> param(
       new pir::Parameter(nullptr, 0, parameter.type()));
   APIBuilder::Instance().SetParameter(name, std::move(param));
   APIBuilder::Instance().GetBuilder()->Build<pir::SetParameterOp>(parameter,
                                                                   name);
 }
 
-pir::OpResult embedding_grad(pir::Value x,
-                             pir::Value weight,
-                             pir::Value out_grad,
+pir::OpResult embedding_grad(const pir::Value& x,
+                             const pir::Value& weight,
+                             const pir::Value& out_grad,
                              int64_t padding_idx,
                              bool sparse) {
   if (weight.type().isa<paddle::dialect::DenseTensorType>()) {
@@ -81,7 +81,8 @@ pir::OpResult embedding_grad(pir::Value x,
   }
 }
 
-pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) {
+pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
+                                  int axis) {
   auto out_grad_combine_op =
       APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
   paddle::dialect::SplitGradOp split_grad_op =
@@ -90,8 +91,8 @@ pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) {
   return split_grad_op.result(0);
 }
 
-pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
-                                  pir::Value axis) {
+pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
+                                  const pir::Value& axis) {
   auto out_grad_combine_op =
       APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
   paddle::dialect::SplitGradOp split_grad_op =
21 changes: 11 additions & 10 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -25,26 +25,27 @@ namespace dialect {
 
 pir::OpResult builtin_combine(const std::vector<pir::Value>& x);
 
-std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
-                                      pir::Value out_grad);
+std::vector<pir::OpResult> add_n_grad(const std::vector<pir::Value>& inputs,
+                                      const pir::Value& out_grad);
 
-pir::OpResult zeros_like(pir::Value x,
+pir::OpResult zeros_like(const pir::Value& x,
                          phi::DataType dtype = phi::DataType::UNDEFINED,
                          const Place& place = {});
 
 pir::OpResult get_parameter(const std::string& name);
 
-void set_parameter(pir::Value parameter, const std::string& name);
+void set_parameter(const pir::Value& parameter, const std::string& name);
 
-pir::OpResult embedding_grad(pir::Value x,
-                             pir::Value weight,
-                             pir::Value out_grad,
+pir::OpResult embedding_grad(const pir::Value& x,
+                             const pir::Value& weight,
+                             const pir::Value& out_grad,
                              int64_t padding_idx = -1,
                              bool sparse = false);
 
-pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis);
+pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
+                                  int axis);
 
-pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
-                                  pir::Value axis);
+pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
+                                  const pir::Value& axis);
 }  // namespace dialect
 }  // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
@@ -34,7 +34,9 @@ std::vector<std::vector<pir::OpResult>> AddNOp::Vjp(
   AddNOp op_obj = op->dyn_cast<AddNOp>();
 
   VLOG(6) << "Prepare inputs of add_n_grad";
-
+  PADDLE_ENFORCE(
+      op_obj.inputs() != nullptr,
+      paddle::platform::errors::Fatal("addn op's inputs can't be null"));
   pir::CombineOp combine_op_obj = op_obj.inputs()
                                       .dyn_cast<pir::OpResult>()
                                       .owner()
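
The added PADDLE_ENFORCE fails fast with a readable error rather than letting the dyn_cast/owner() chain below it dereference a null inputs handle. A self-contained sketch of the same fail-fast idea in plain C++, with stand-in types instead of the real pir and PADDLE_ENFORCE machinery:

    #include <stdexcept>

    struct Value {
      void* impl = nullptr;  // stand-in handle, not the real pir::Value layout
      bool is_null() const { return impl == nullptr; }
    };

    void prepare_add_n_grad_inputs(const Value& inputs) {
      // Validate the handle up front; a null handle would otherwise surface
      // later as an opaque crash deep inside the cast/owner chain.
      if (inputs.is_null()) {
        throw std::runtime_error("addn op's inputs can't be null");
      }
      // ... from here it is safe to inspect the inputs handle
    }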
1 change: 0 additions & 1 deletion paddle/fluid/primitive/backend/manual/manual_backend.h
@@ -18,7 +18,6 @@
 #include <vector>
 
 #include "paddle/phi/api/include/tensor.h"
-#include "paddle/utils/optional.h"
 
 namespace paddle {
 namespace primitive {
