Skip to content

Commit

Permalink
【pir】 add tensorarray op createarrylike, add_n (#60460)
Browse files Browse the repository at this point in the history
* optimize backward

* [PIR] add vjp interface for while op

* [PIR] fix ci error.

* modify while stopgradient

* merge

* modify while grad bug

* modify while grad op

* modify

* increment vp

* [PIR] add get_used_external_value interface for block.

* while case

* delete print

* delete print

* Update python/paddle/autograd/ir_backward.py

* [PIR] add unit_test for get_used_external_value

* modify while_loop

* code_style

* modofy ci bug

* modify while api

* modify ci

* modify array

* Update python/paddle/autograd/ir_backward.py

* Update test/legacy_test/test_cond.py

* update

* modify array_write grad info

* merge

* add_n and createarraylike

* conflict

* modify exe bug

* modify kernel choose

---------

Co-authored-by: winter-wang <1030748926@qq.com>
  • Loading branch information
xiaoguoguo626807 and winter-wang authored Jan 3, 2024
1 parent be8bc1e commit deb5397
Show file tree
Hide file tree
Showing 15 changed files with 508 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,13 @@ void BuildPhiContext(pir::Operation* op,
} else if (variable_array[i]->IsType<phi::SelectedRows>()) {
inputs.emplace_back(InType(const_cast<phi::SelectedRows*>(
&(variable_array[i]->Get<phi::SelectedRows>()))));
} else if (variable_array[i]->IsType<phi::TensorArray>()) {
inputs.emplace_back(InType(const_cast<phi::TensorArray*>(
&(variable_array[i]->Get<phi::TensorArray>()))));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support Vector<DenseTensor> and vector<SelectedRows> now, "
"Only support Vector<DenseTensor> and vector<SelectedRows> "
"and vector<TensorArray> now "
"not support vector<%d>.",
variable_array[i]->Type()));
}
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ pir::OpResult create_array(phi::DataType dtype) {
return create_array_op.out();
}

pir::OpResult create_array_like(pir::Value input, float value) {
auto create_array_like_op =
ApiBuilder::Instance()
.GetBuilder()
->Build<paddle::dialect::CreateArrayLikeOp>(input, value);
return create_array_like_op.out();
}

pir::OpResult array_length(pir::Value x) {
auto array_length_op = ApiBuilder::Instance()
.GetBuilder()
Expand Down Expand Up @@ -165,6 +173,15 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1));
}

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs) {
auto inputs_combine_op =
ApiBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(inputs);
paddle::dialect::AddNArrayOp add_n_array_op =
ApiBuilder::Instance().GetBuilder()->Build<paddle::dialect::AddNArrayOp>(
inputs_combine_op.out());
return add_n_array_op.result(0);
}

pir::OpResult slice_array_dense(pir::Value input, pir::Value starts) {
auto op = ApiBuilder::Instance()
.GetBuilder()
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pir::OpResult zeros(const std::vector<int64_t>& shape,

pir::OpResult create_array(phi::DataType dtype);

pir::OpResult create_array_like(pir::Value input, float value);

pir::OpResult array_length(pir::Value x);

pir::OpResult array_read(pir::Value array, pir::Value i);
Expand All @@ -72,6 +74,8 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
int axis,
bool use_stack);

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs);

pir::OpResult slice_array_dense(pir::Value input, pir::Value starts);

} // namespace dialect
Expand Down
255 changes: 249 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
#ifdef GET_OP_LIST
#undef GET_OP_LIST
paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
paddle::dialect::AddNWithKernelOp, paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::AddNWithKernelOp, paddle::dialect::AddNArrayOp,
paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::FusedGemmEpilogueGradOp, paddle::dialect::SplitGradOp,
paddle::dialect::ExpandOp, paddle::dialect::CreateArrayOp,
paddle::dialect::ArrayLengthOp, paddle::dialect::ArrayReadOp,
paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp,
paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op,
paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp,
paddle::dialect::IncrementOp, paddle::dialect::Increment_Op
paddle::dialect::CreateArrayLikeOp, paddle::dialect::ArrayLengthOp,
paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op,
paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
paddle::dialect::Increment_Op
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
Expand Down Expand Up @@ -421,6 +423,136 @@ void AddNWithKernelOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}

OpInfoTuple AddNArrayOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
OpInputInfo("inputs",
"pir::VectorType<paddle::dialect::DenseTensorArrayType>",
false,
false,
false,
true)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
std::vector<paddle::dialect::OpOutputInfo> outputs = {OpOutputInfo(
"out", "paddle::dialect::DenseTensorArrayType", false, false)};
paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("AddNTensorArrayInferMeta",
{"inputs"},
"add_n_array",
{"inputs"},
{},
{},
{},
{});

return std::make_tuple(
inputs, attributes, outputs, run_time_info, "add_n_array");
}

void AddNArrayOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNArrayOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", input_size));
if (auto vec_type =
(*this)->operand(0).type().dyn_cast<pir::VectorType>()) {
for (size_t i = 0; i < vec_type.size(); ++i) {
PADDLE_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
}
} else {
PADDLE_ENFORCE((*this)
->operand(0)
.type()
.isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
}
}
VLOG(4) << "Verifying attributes:";
{
// Attributes num is 0, not need to check attributes type.
}
VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
}
VLOG(4) << "End Verifying for: AddNArrayOp.";
}

void AddNArrayOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value inputs_) {
VLOG(4) << "Start build AddNArrayOp";

VLOG(4) << "Builder construction inputs";
argument.AddInput(inputs_);

VLOG(4) << "Builder construction attributes";

VLOG(4) << "Builder construction outputs";
pir::VectorType inputs = inputs_.type().dyn_cast<pir::VectorType>();

std::vector<paddle::dialect::IrTensor> vec_dense_inputs;
for (size_t i = 0; i < inputs.size(); i++) {
vec_dense_inputs.push_back(paddle::dialect::IrTensor(
TransToPhiDataType(
inputs[i]
.dyn_cast<paddle::dialect::DenseTensorArrayType>()
.dtype()),
{},
inputs[i]
.dyn_cast<paddle::dialect::DenseTensorArrayType>()
.data_layout(),
{}));
}

std::vector<paddle::dialect::IrMetaTensor> vec_meta_inputs;
for (size_t i = 0; i < vec_dense_inputs.size(); i++) {
vec_meta_inputs.push_back(
paddle::dialect::IrMetaTensor(&vec_dense_inputs[i]));
}

std::vector<const phi::MetaTensor *> meta_inputs;
for (size_t i = 0; i < static_cast<size_t>(vec_meta_inputs.size()); i++) {
meta_inputs.push_back(&vec_meta_inputs[i]);
}

paddle::dialect::IrTensor dense_out;
paddle::dialect::IrMetaTensor meta_out(&dense_out);

phi::AddNTensorArrayInferMeta(
meta_inputs, &meta_out, phi::MetaConfig(false, false));
std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get(
pir::IrContext::Instance(),
TransToIrDataType(dense_out.dtype()),
dense_out.layout());

argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void AddNArrayOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::AddNTensorArrayInferMeta);
fn(infer_meta);
}

const char *FusedGemmEpilogueOp::attributes_name[3] = {
"trans_x", "trans_y", "activation"};

Expand Down Expand Up @@ -1156,6 +1288,114 @@ void CreateArrayOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}

const char *CreateArrayLikeOp::attributes_name[1] = {"val"};

OpInfoTuple CreateArrayLikeOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
paddle::dialect::OpInputInfo("input",
"paddle::dialect::DenseTensorArrayType",
false,
false,
false,
false)};

std::vector<paddle::dialect::OpAttributeInfo> attributes = {
paddle::dialect::OpAttributeInfo("val", "pir::FloatAttribute", "")};

std::vector<paddle::dialect::OpOutputInfo> outputs = {OpOutputInfo(
"out", "paddle::dialect::DenseTensorArrayType", false, false)};

paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("CreateArrayLikeInferMeta",
{"input"},
"create_array_like",
{"input", "val"},
{},
{},
{},
{});

return std::make_tuple(
inputs, attributes, outputs, run_time_info, "create_array_like");
}

void CreateArrayLikeOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value &input_, // NOLINT
float &val) {
VLOG(4) << "Start build CreateArrayLikeOp";
VLOG(4) << "Builder construction inputs";
std::vector<pir::Value> argument_inputs = {input_};
argument.AddInputs(argument_inputs);

VLOG(4) << "Builder construction attributes";
pir::Attribute attr_val =
pir::FloatAttribute::get(pir::IrContext::Instance(), val);
argument.AddAttribute("val", attr_val);
VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorArrayType input_type =
input_.type().dyn_cast<paddle::dialect::DenseTensorArrayType>();
paddle::dialect::IrTensor dense_input(
paddle::dialect::TransToPhiDataType(input_type.dtype()),
{},
input_type.data_layout(),
{});

paddle::dialect::IrMetaTensor meta_input(&dense_input);

paddle::dialect::IrTensor dense_out;
paddle::dialect::IrMetaTensor meta_out(&dense_out);

phi::CreateArrayLikeInferMeta(meta_input, &meta_out);

std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get(
pir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_out.dtype()),
dense_out.layout());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void CreateArrayLikeOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"CreateArrayLikeOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", input_size));
}
VLOG(4) << "Verifying attributes:";
{
auto &attributes = this->attributes();
PADDLE_ENFORCE(attributes.count("val") > 0, "val does not exist.");
}
VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
}
VLOG(4) << "End Verifying for: CreateArrayLikeOp.";
}

void CreateArrayLikeOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::CreateArrayLikeInferMeta);
fn(infer_meta);
}

OpInfoTuple ArrayLengthOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
OpInputInfo("x",
Expand Down Expand Up @@ -1319,6 +1559,7 @@ void ArrayReadOp::Build(pir::Builder &builder,
dense_out.lod());
argument_outputs.push_back(out_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void ArrayReadOp::Build(pir::Builder &builder,
Expand Down Expand Up @@ -2691,9 +2932,11 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayLikeOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayLengthOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayReadOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayWrite_Op)
Expand Down
Loading

0 comments on commit deb5397

Please sign in to comment.