[PIR] add TensorArray ops create_array_like, add_n #60460

Merged Jan 3, 2024
Changes from all commits (48 commits)
001d799
optimize backward
xiaoguoguo626807 Dec 8, 2023
05ca298
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 11, 2023
4fd113e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 12, 2023
8f60538
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 13, 2023
8854896
[PIR] add vjp interface for while op
winter-wang Dec 12, 2023
7e177f6
[PIR] fix ci error.
winter-wang Dec 13, 2023
11c8656
modify while stopgradient
xiaoguoguo626807 Dec 14, 2023
d8c3936
merge
xiaoguoguo626807 Dec 14, 2023
da62e16
merge
xiaoguoguo626807 Dec 15, 2023
67ed811
merge
xiaoguoguo626807 Dec 15, 2023
30bba32
modify while grad bug
xiaoguoguo626807 Dec 18, 2023
53f2920
merge
xiaoguoguo626807 Dec 18, 2023
fde161c
modify while grad op
xiaoguoguo626807 Dec 18, 2023
fdc12c7
modify
xiaoguoguo626807 Dec 18, 2023
e3d19b9
increment vp
xiaoguoguo626807 Dec 19, 2023
600d99c
merge
xiaoguoguo626807 Dec 20, 2023
0913436
[PIR] add get_used_external_value interface for block.
winter-wang Dec 19, 2023
63344b7
while case
xiaoguoguo626807 Dec 20, 2023
59ad2fc
delete print
xiaoguoguo626807 Dec 20, 2023
f4eceb6
delete print
xiaoguoguo626807 Dec 20, 2023
1c9eb96
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 20, 2023
4beaa79
Merge branch 'develop' into while_2
xiaoguoguo626807 Dec 20, 2023
df0b46a
[PIR] add unit_test for get_used_external_value
winter-wang Dec 20, 2023
65083df
modify while_loop
xiaoguoguo626807 Dec 21, 2023
f2f4fa0
Merge branch 'while_2' of https://github.com/xiaoguoguo626807/Paddle …
xiaoguoguo626807 Dec 21, 2023
f8e3ac4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 21, 2023
95bc3d7
code_style
xiaoguoguo626807 Dec 21, 2023
37e807c
modify ci bug
xiaoguoguo626807 Dec 21, 2023
52afa31
Merge branch 'develop', commit 'refs/pull/60159/head' of https://gith…
xiaoguoguo626807 Dec 21, 2023
48de124
modify while api
xiaoguoguo626807 Dec 22, 2023
a7f13c9
merge
xiaoguoguo626807 Dec 25, 2023
adb627a
modify ci
xiaoguoguo626807 Dec 25, 2023
e90cd79
modify array
xiaoguoguo626807 Dec 26, 2023
17e17d4
merge
xiaoguoguo626807 Dec 26, 2023
1aa50c0
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 26, 2023
eef3e24
Update test/legacy_test/test_cond.py
xiaoguoguo626807 Dec 26, 2023
d78b574
update
xiaoguoguo626807 Dec 26, 2023
d404059
modify array_write grad info
xiaoguoguo626807 Dec 26, 2023
fb8c52d
merge
xiaoguoguo626807 Dec 26, 2023
f3e09e5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 26, 2023
44d856f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 27, 2023
39fcb4b
merge
xiaoguoguo626807 Dec 27, 2023
655482a
add_n and createarraylike
xiaoguoguo626807 Dec 29, 2023
ec43be4
merge
xiaoguoguo626807 Dec 29, 2023
785d367
conflict
xiaoguoguo626807 Dec 29, 2023
b6e2388
modify exe bug
xiaoguoguo626807 Dec 29, 2023
5315369
modify kernel choose
xiaoguoguo626807 Jan 2, 2024
5f60450
fix conflict
xiaoguoguo626807 Jan 2, 2024
@@ -213,9 +213,13 @@ void BuildPhiContext(pir::Operation* op,
} else if (variable_array[i]->IsType<phi::SelectedRows>()) {
inputs.emplace_back(InType(const_cast<phi::SelectedRows*>(
&(variable_array[i]->Get<phi::SelectedRows>()))));
} else if (variable_array[i]->IsType<phi::TensorArray>()) {
inputs.emplace_back(InType(const_cast<phi::TensorArray*>(
&(variable_array[i]->Get<phi::TensorArray>()))));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support Vector<DenseTensor> and vector<SelectedRows> now, "
"Only support Vector<DenseTensor> and vector<SelectedRows> "
"and vector<TensorArray> now "
"not support vector<%d>.",
variable_array[i]->Type()));
}
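This hunk extends the input dispatch in BuildPhiContext so that vector inputs holding phi::TensorArray are unwrapped and forwarded to the kernel context, just like DenseTensor and SelectedRows. As a minimal sketch of the IsType/Get idiom being extended (a hypothetical helper, not part of this PR; Variable stands in for the framework's variable holder):

// Hypothetical illustration of the IsType<T>/Get<T> dispatch idiom used in
// BuildPhiContext: probe the concrete payload type, then unwrap it.
template <typename T, typename Variable>
const T* TryGet(const Variable* var) {
  // Returns the payload if `var` holds a T, nullptr otherwise.
  return var->IsType<T>() ? &var->Get<T>() : nullptr;
}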
17 changes: 17 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -133,6 +133,14 @@ pir::OpResult create_array(phi::DataType dtype) {
return create_array_op.out();
}

pir::OpResult create_array_like(pir::Value input, float value) {
auto create_array_like_op =
ApiBuilder::Instance()
.GetBuilder()
->Build<paddle::dialect::CreateArrayLikeOp>(input, value);
return create_array_like_op.out();
}

pir::OpResult array_length(pir::Value x) {
auto array_length_op = ApiBuilder::Instance()
.GetBuilder()
@@ -165,6 +173,15 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1));
}

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs) {
auto inputs_combine_op =
ApiBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(inputs);
paddle::dialect::AddNArrayOp add_n_array_op =
ApiBuilder::Instance().GetBuilder()->Build<paddle::dialect::AddNArrayOp>(
inputs_combine_op.out());
return add_n_array_op.result(0);
}

pir::OpResult slice_array_dense(pir::Value input, pir::Value starts) {
auto op = ApiBuilder::Instance()
.GetBuilder()
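Taken together, these helpers mirror the existing array APIs in this file. A hedged usage sketch follows (hypothetical driver code, assuming the usual ApiBuilder/program setup used by the dialect tests; only the three signatures added or already declared in this file are relied on):

// Hypothetical sketch: build IR that creates an array, clones its structure
// filled with a constant, and sums the two arrays with add_n_array.
pir::Value arr = paddle::dialect::create_array(phi::DataType::FLOAT32);
// ... populate `arr` via the array write API ...
pir::Value like = paddle::dialect::create_array_like(arr, /*value=*/0.0f);
pir::OpResult sum = paddle::dialect::add_n_array({arr, like});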
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -62,6 +62,8 @@ pir::OpResult zeros(const std::vector<int64_t>& shape,

pir::OpResult create_array(phi::DataType dtype);

pir::OpResult create_array_like(pir::Value input, float value);

pir::OpResult array_length(pir::Value x);

pir::OpResult array_read(pir::Value array, pir::Value i);
@@ -72,6 +74,8 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
int axis,
bool use_stack);

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs);

pir::OpResult slice_array_dense(pir::Value input, pir::Value starts);

} // namespace dialect
255 changes: 249 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -14,14 +14,16 @@
#ifdef GET_OP_LIST
#undef GET_OP_LIST
paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
paddle::dialect::AddNWithKernelOp, paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::AddNWithKernelOp, paddle::dialect::AddNArrayOp,
paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::FusedGemmEpilogueGradOp, paddle::dialect::SplitGradOp,
paddle::dialect::ExpandOp, paddle::dialect::CreateArrayOp,
paddle::dialect::ArrayLengthOp, paddle::dialect::ArrayReadOp,
paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp,
paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op,
paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp,
paddle::dialect::IncrementOp, paddle::dialect::Increment_Op
paddle::dialect::CreateArrayLikeOp, paddle::dialect::ArrayLengthOp,
paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op,
paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
paddle::dialect::Increment_Op
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -421,6 +423,136 @@ void AddNWithKernelOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}

OpInfoTuple AddNArrayOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
OpInputInfo("inputs",
"pir::VectorType<paddle::dialect::DenseTensorArrayType>",
false,
false,
false,
true)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
std::vector<paddle::dialect::OpOutputInfo> outputs = {OpOutputInfo(
"out", "paddle::dialect::DenseTensorArrayType", false, false)};
paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("AddNTensorArrayInferMeta",
{"inputs"},
"add_n_array",
{"inputs"},
{},
{},
{},
{});

return std::make_tuple(
inputs, attributes, outputs, run_time_info, "add_n_array");
}

void AddNArrayOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNArrayOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", input_size));
if (auto vec_type =
(*this)->operand(0).type().dyn_cast<pir::VectorType>()) {
for (size_t i = 0; i < vec_type.size(); ++i) {
PADDLE_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
}
} else {
PADDLE_ENFORCE((*this)
->operand(0)
.type()
.isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
}
}
VLOG(4) << "Verifying attributes:";
{
// Attributes num is 0, not need to check attributes type.
}
VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
}
VLOG(4) << "End Verifying for: AddNArrayOp.";
}

void AddNArrayOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value inputs_) {
VLOG(4) << "Start build AddNArrayOp";

VLOG(4) << "Builder construction inputs";
argument.AddInput(inputs_);

VLOG(4) << "Builder construction attributes";

VLOG(4) << "Builder construction outputs";
pir::VectorType inputs = inputs_.type().dyn_cast<pir::VectorType>();

std::vector<paddle::dialect::IrTensor> vec_dense_inputs;
for (size_t i = 0; i < inputs.size(); i++) {
vec_dense_inputs.push_back(paddle::dialect::IrTensor(
TransToPhiDataType(
inputs[i]
.dyn_cast<paddle::dialect::DenseTensorArrayType>()
.dtype()),
{},
inputs[i]
.dyn_cast<paddle::dialect::DenseTensorArrayType>()
.data_layout(),
{}));
}

std::vector<paddle::dialect::IrMetaTensor> vec_meta_inputs;
for (size_t i = 0; i < vec_dense_inputs.size(); i++) {
vec_meta_inputs.push_back(
paddle::dialect::IrMetaTensor(&vec_dense_inputs[i]));
}

std::vector<const phi::MetaTensor *> meta_inputs;
for (size_t i = 0; i < static_cast<size_t>(vec_meta_inputs.size()); i++) {
meta_inputs.push_back(&vec_meta_inputs[i]);
}

paddle::dialect::IrTensor dense_out;
paddle::dialect::IrMetaTensor meta_out(&dense_out);

phi::AddNTensorArrayInferMeta(
meta_inputs, &meta_out, phi::MetaConfig(false, false));
std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get(
pir::IrContext::Instance(),
TransToIrDataType(dense_out.dtype()),
dense_out.layout());

argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void AddNArrayOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::AddNTensorArrayInferMeta);
fn(infer_meta);
}

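AddNArrayOp::Build runs phi's AddNTensorArrayInferMeta at construction time on IrMetaTensor stand-ins to derive the output array's dtype and layout. Semantically, add_n over TensorArrays sums the inputs slot by slot; the following is a self-contained toy model (hypothetical, with tensors flattened to float vectors; the real behavior is defined by the phi kernel):

#include <cassert>
#include <vector>

using Tensor = std::vector<float>;       // toy stand-in for phi::DenseTensor
using TensorArray = std::vector<Tensor>; // toy stand-in for phi::TensorArray

// The result's i-th tensor is the element-wise sum of every input's i-th
// tensor; all inputs are assumed to have matching structure.
TensorArray AddNArray(const std::vector<TensorArray>& inputs) {
  assert(!inputs.empty());
  TensorArray out = inputs[0];
  for (size_t k = 1; k < inputs.size(); ++k) {
    assert(inputs[k].size() == out.size());
    for (size_t i = 0; i < out.size(); ++i) {
      assert(inputs[k][i].size() == out[i].size());
      for (size_t j = 0; j < out[i].size(); ++j) {
        out[i][j] += inputs[k][i][j];
      }
    }
  }
  return out;
}
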
const char *FusedGemmEpilogueOp::attributes_name[3] = {
"trans_x", "trans_y", "activation"};

@@ -1156,6 +1288,114 @@ void CreateArrayOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}

const char *CreateArrayLikeOp::attributes_name[1] = {"val"};

OpInfoTuple CreateArrayLikeOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
paddle::dialect::OpInputInfo("input",
"paddle::dialect::DenseTensorArrayType",
false,
false,
false,
false)};

std::vector<paddle::dialect::OpAttributeInfo> attributes = {
paddle::dialect::OpAttributeInfo("val", "pir::FloatAttribute", "")};

std::vector<paddle::dialect::OpOutputInfo> outputs = {OpOutputInfo(
"out", "paddle::dialect::DenseTensorArrayType", false, false)};

paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("CreateArrayLikeInferMeta",
{"input"},
"create_array_like",
{"input", "val"},
{},
{},
{},
{});

return std::make_tuple(
inputs, attributes, outputs, run_time_info, "create_array_like");
}

void CreateArrayLikeOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value &input_, // NOLINT
float &val) {
VLOG(4) << "Start build CreateArrayLikeOp";
VLOG(4) << "Builder construction inputs";
std::vector<pir::Value> argument_inputs = {input_};
argument.AddInputs(argument_inputs);

VLOG(4) << "Builder construction attributes";
pir::Attribute attr_val =
pir::FloatAttribute::get(pir::IrContext::Instance(), val);
argument.AddAttribute("val", attr_val);
VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorArrayType input_type =
input_.type().dyn_cast<paddle::dialect::DenseTensorArrayType>();
paddle::dialect::IrTensor dense_input(
paddle::dialect::TransToPhiDataType(input_type.dtype()),
{},
input_type.data_layout(),
{});

paddle::dialect::IrMetaTensor meta_input(&dense_input);

paddle::dialect::IrTensor dense_out;
paddle::dialect::IrMetaTensor meta_out(&dense_out);

phi::CreateArrayLikeInferMeta(meta_input, &meta_out);

std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get(
pir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_out.dtype()),
dense_out.layout());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void CreateArrayLikeOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"CreateArrayLikeOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", input_size));
}
VLOG(4) << "Verifying attributes:";
{
auto &attributes = this->attributes();
PADDLE_ENFORCE(attributes.count("val") > 0, "val does not exist.");
}
VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
}
VLOG(4) << "End Verifying for: CreateArrayLikeOp.";
}

void CreateArrayLikeOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::CreateArrayLikeInferMeta);
fn(infer_meta);
}

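By the same token, create_array_like reads as: produce a new array with the input's structure, filled with the attribute val. A toy model in the spirit of the add_n sketch above (hypothetical; the real behavior is defined by phi's CreateArrayLikeInferMeta and kernel):

#include <algorithm>
#include <vector>

using Tensor = std::vector<float>;
using TensorArray = std::vector<Tensor>;

// Same length and per-slot shape as `input`, every element set to `val`.
TensorArray CreateArrayLike(const TensorArray& input, float val) {
  TensorArray out = input;
  for (Tensor& t : out) {
    std::fill(t.begin(), t.end(), val);
  }
  return out;
}
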
OpInfoTuple ArrayLengthOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
OpInputInfo("x",
@@ -1319,6 +1559,7 @@ void ArrayReadOp::Build(pir::Builder &builder,
dense_out.lod());
argument_outputs.push_back(out_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void ArrayReadOp::Build(pir::Builder &builder,
@@ -2691,9 +2932,11 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayLikeOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayLengthOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayReadOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayWrite_Op)