[PRIM][PIR] Support all vjp gen #58023

Merged (9 commits) on Oct 12, 2023
6 changes: 3 additions & 3 deletions paddle/fluid/operators/generator/type_mapping.py
@@ -48,7 +48,7 @@
'int64_t[]': 'const std::vector<int64_t>&',
'float[]': 'const std::vector<float>&',
'double[]': 'const std::vector<double>&',
- 'str[]': 'const std::vector<<std::string>&',
+ 'str[]': 'const std::vector<std::string>&',
}

opmaker_attr_types_map = {
@@ -86,8 +86,8 @@
}

optional_output_type_map = {
- 'Tensor': 'const paddle::optional<Tensor>&',
- 'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
+ 'Tensor': 'const paddle::optional<Tensor>',
+ 'Tensor[]': 'const paddle::optional<std::vector<Tensor>>',
Contributor (review comment): Why was the reference removed here while the const was kept?

Contributor Author (reply): Thanks for pointing this out. To avoid blocking everyone's work, I'll fix it in a follow-up PR.

}

# ------------------------------ phi attr ------------------------------
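A note on the optional_output_type_map change above, since the review thread flags it: here is a minimal sketch of the likely rationale, assuming the generated wrapper builds the optional locally (illustrative code, not from this PR; maybe_out is a hypothetical name):

// Returning a reference to a locally-built optional would dangle, so the
// signature must return by value; the leading const on a by-value return is
// then redundant, which is presumably what the follow-up PR will drop.
paddle::optional<Tensor> maybe_out(bool has_out, const Tensor& t) {
  if (has_out) {
    return paddle::optional<Tensor>(t);  // optional materialized in the wrapper
  }
  return paddle::optional<Tensor>();  // empty optional, returned by value
}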
22 changes: 12 additions & 10 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -197,7 +197,9 @@ def _is_optional_input(self, op_info, input_name):
return True
return False

- def _is_optinonal_output(self, op_info, output_name):
+ def _is_optional_output(self, op_info, op_name, output_name):
+     if op_name.endswith(('_grad', '_grad_')):
+         return False
inplace_map = op_info.inplace_map
input_optional_list = op_info.input_optional_list
input_name_list = op_info.input_name_list
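Besides fixing the _is_optinonal_output typo, the new op_name parameter short-circuits backward ops: any op whose name ends in _grad or _grad_ never reports optional outputs. A hedged sketch of the effect on generated declarations (fwd_op and fwd_op_grad are hypothetical names, and the pir::OpResult return type is an assumption about the generated API):

// Forward op: a maybe-absent output keeps its optional wrapper.
std::tuple<pir::OpResult, paddle::optional<pir::OpResult>> fwd_op(pir::Value x);
// Backward op: with the new check, *_grad outputs are never optional.
pir::OpResult fwd_op_grad(pir::Value x, pir::Value out_grad);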
@@ -271,7 +273,7 @@ def _gen_api_args(
)
return (inputs + ', ' + attrs).strip(', ')

- def _gen_ret_type(self, op_info):
+ def _gen_ret_type(self, op_info, op_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
@@ -285,15 +287,15 @@ def _gen_ret_type(self, op_info):
):
if intermediate == 'true':
continue
- if self._is_optinonal_output(op_info, name):
+ if self._is_optional_output(op_info, op_name, name):
ret.append(OPTIONAL_OUTPUT_TYPE_MAP[type])
else:
ret.append(OUTPUT_TYPE_MAP[type])
return 'std::tuple<{}>'.format(', '.join(ret))
elif output_num == 1:
index = intermediate_list.index('false')
name = name_list[index]
- if self._is_optinonal_output(op_info, name):
+ if self._is_optional_output(op_info, op_name, name):
return OPTIONAL_OUTPUT_TYPE_MAP[type_list[index]]
else:
return OUTPUT_TYPE_MAP[type_list[index]]
@@ -304,7 +306,7 @@ def _gen_one_declare(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
return API_DECLARE_TEMPLATE.format(
- ret_type=self._gen_ret_type(op_info),
+ ret_type=self._gen_ret_type(op_info, op_name),
api_name=op_name,
args=self._gen_api_args(
op_info, True, is_mutable_attr, is_vector_mutable_attr
@@ -367,7 +369,7 @@ def _gen_handle_optional_outputs(self, op_info, op_name):
):
if intermediate == 'true':
continue
- if self._is_optinonal_output(op_info, name):
+ if self._is_optional_output(op_info, op_name, name):
if VECTOR_TYPE in type:
ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format(
name=name,
@@ -461,7 +463,7 @@ def _gen_compute_op(
op_inst_name,
)

- def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
+ def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
@@ -480,7 +482,7 @@ def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
):
if intermediate == 'true':
continue
- if self._is_optinonal_output(op_info, name):
+ if self._is_optional_output(op_info, op_name, name):
ret_list.append(f'optional_{name}')
elif VECTOR_TYPE in type:
split_op_name = f'{name}_split_op'
@@ -503,7 +505,7 @@ def _gen_return_result(self, ret_list):
def _gen_one_impl(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
- ret_type = self._gen_ret_type(op_info)
+ ret_type = self._gen_ret_type(op_info, op_name)
in_combine, in_combine_op_list = self._gen_in_combine(
op_info, is_mutable_attr, is_vector_mutable_attr
)
@@ -514,7 +516,7 @@ def _gen_one_impl(
compute_op += f' (void){op_inst_name};'

out_split, ret_list = self._gen_out_split_and_ret_list(
- op_info, op_inst_name
+ op_info, op_name, op_inst_name
)

ret = API_IMPL_TEMPLATE.format(
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -144,6 +144,7 @@ def GenBuildInputArgsStr(
'int': 'phi::DataType::INT32',
'int64_t': 'phi::DataType::INT64',
'float': 'phi::DataType::FLOAT32',
'double': 'phi::DataType::FLOAT64',
'std::vector<int64_t>': 'phi::DataType::INT64',
'const std::vector<int64_t>&': 'phi::DataType::INT64',
'bool': 'phi::DataType::BOOL',
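The new 'double' entry lets a double attribute be tagged FLOAT64 when the generated build code wraps it. A minimal sketch of why the distinction matters, assuming phi::Scalar infers its dtype from the constructor argument type:

#include "paddle/phi/common/scalar.h"

phi::Scalar from_float(1e-8f);  // carries phi::DataType::FLOAT32
phi::Scalar from_double(1e-8);  // carries phi::DataType::FLOAT64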
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -477,7 +477,7 @@ def parse_mutable_attribute(self):
if (self.op_compat_item['op'] == "isclose") or (
self.op_compat_item['op'] == "allclose"
):
data_type = "float"
data_type = "double"
mutable_attribute_type_list.append(
[
"paddle::dialect::ScalarAttribute",
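Switching the isclose/allclose mutable-attribute type from float to double matches the double-precision rtol/atol these ops take (an assumption about the kernel signatures) and avoids a narrowing round-trip:

float  atol_f = 1e-08f;                       // nearest float to 1e-08
double atol_d = static_cast<double>(atol_f);  // widened value is not exactly 1e-08
// Modeling the attribute as double from the start sidesteps this drift.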
19 changes: 19 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -100,5 +100,24 @@ pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
out_grad_combine_op.out(), axis);
return split_grad_op.result(0);
}

pir::OpResult ones(const std::vector<int64_t>& shape,
phi::DataType dtype,
const Place& place) {
return paddle::dialect::full(shape, 1, dtype, place);
}

pir::OpResult ones_like(pir::Value x_,
phi::DataType dtype,
const Place& place) {
return paddle::dialect::full_like(x_, 1, dtype, place);
}

pir::OpResult zeros(const std::vector<int64_t>& shape,
phi::DataType dtype,
const Place& place) {
return paddle::dialect::full(shape, 0, dtype, place);
}

} // namespace dialect
} // namespace paddle
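The three helpers above are thin wrappers that lower to full/full_like. A hedged usage sketch, assuming an IR builder is already set up as the current insertion point:

// Build a 2x3 float32 tensor of ones and a matching tensor of zeros.
pir::OpResult a = paddle::dialect::ones({2, 3}, phi::DataType::FLOAT32, phi::CPUPlace());
pir::OpResult b = paddle::dialect::zeros({2, 3}, phi::DataType::FLOAT32, phi::CPUPlace());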
13 changes: 13 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -47,5 +47,18 @@ pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,

pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
const pir::Value& axis);

pir::OpResult ones(const std::vector<int64_t>& shape,
phi::DataType dtype = phi::DataType::FLOAT32,
const Place& place = phi::CPUPlace());

pir::OpResult ones_like(pir::Value x_,
phi::DataType dtype = phi::DataType::UNDEFINED,
const Place& place = {});

pir::OpResult zeros(const std::vector<int64_t>& shape,
phi::DataType dtype = phi::DataType::FLOAT32,
const Place& place = phi::CPUPlace());

} // namespace dialect
} // namespace paddle
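Note the defaults above: ones and zeros default to FLOAT32 on CPUPlace, while ones_like defaults dtype to UNDEFINED and place to an empty Place, which presumably tells full_like to inherit both from x_ (an assumption; the PR does not spell this out):

// With all defaults, y is expected to match x_'s dtype and place.
pir::OpResult y = paddle::dialect::ones_like(x_);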
9 changes: 8 additions & 1 deletion paddle/fluid/primitive/backend/manual/manual_backend.h
@@ -24,14 +24,21 @@ namespace primitive {
namespace backend {

using Tensor = paddle::Tensor;
- using Scalar = paddle::experimental::Scalar;
+ using Scalar = phi::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;

template <typename T>
std::vector<Tensor> add_n_grad(const std::vector<Tensor>& x,
const Tensor& out_grad);

template <typename T>
Tensor embedding_grad(const Tensor& x,
const Tensor& weight,
const Tensor& out_grad,
int64_t padding_idx = -1,
bool sparse = false);

} // namespace backend
} // namespace primitive
} // namespace paddle
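Each primitive in this backend is a function template over the tensor implementation, so the same vjp rule can be instantiated for eager or static (IR-building) execution; this PR adds the manual embedding_grad entry. A hedged call sketch for the static instantiation, assuming LazyTensor lives in paddle::primitive:

// x, weight and out_grad are paddle::Tensor objects backed by LazyTensor.
paddle::Tensor w_grad =
    paddle::primitive::backend::embedding_grad<paddle::primitive::LazyTensor>(
        x, weight, out_grad, /*padding_idx=*/-1, /*sparse=*/false);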
17 changes: 17 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_static_backend.cc
@@ -45,6 +45,23 @@ std::vector<Tensor> add_n_grad<LazyTensor>(const std::vector<Tensor>& x,
return x_grad;
}

template <>
Tensor embedding_grad<LazyTensor>(const Tensor& x,
const Tensor& weight,
const Tensor& out_grad,
int64_t padding_idx,
bool sparse) {
pir::Value x_res = std::static_pointer_cast<LazyTensor>(x.impl())->value();
pir::Value weight_res =
std::static_pointer_cast<LazyTensor>(weight.impl())->value();
pir::Value out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())->value();
auto op_res = paddle::dialect::embedding_grad(
x_res, weight_res, out_grad_res, padding_idx, sparse);
Tensor out(std::make_shared<LazyTensor>(op_res));
return out;
}

} // namespace backend
} // namespace primitive
} // namespace paddle
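The specialization above follows the recurring LazyTensor bridge pattern: unwrap each Tensor to its underlying pir::Value, call the dialect API to emit the op, then rewrap the result. A condensed fragment of the pattern as it would appear inside such a specialization (some_op is a hypothetical dialect call):

pir::Value v =
    std::static_pointer_cast<LazyTensor>(t.impl())->value();  // unwrap input t
auto op_res = paddle::dialect::some_op(v);             // emit the IR op
return Tensor(std::make_shared<LazyTensor>(op_res));   // rewrap for the caller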