Skip to content

Commit

Permalink
[IR] polish the new ir api name. (#54562)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Jun 13, 2023
1 parent 53f2466 commit eac99c5
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 116 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/ir/dialect/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
"static const char *attributes_name[{attribute_num}];"
)

OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->GetOperandByIndex({input_index}); }}
OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->operand({input_index}); }}
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation()->GetResultByIndex({output_index}); }}
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation()->result({output_index}); }}
"""

# =====================================
Expand Down Expand Up @@ -817,11 +817,11 @@ def GenBuildInserFullForMutableAttribute(
build_mutable_attribute = ""
BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name}
paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build<paddle::dialect::FullIntArrayOp>({attr_name}, {phi_dtype}, phi::CPUPlace());
ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0);
"""
BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name}
paddle::dialect::FullOp full_{attr_name}_op = builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace());
ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0);
"""
for idx in range(len(op_mutable_attribute_name_list)):
attr_name = op_mutable_attribute_name_list[idx]
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
{src_vec_type[defining_info.idx_in_vector]},
op_info);
program->block()->push_back(operation);
ir::OpResult target_op_result = operation->GetResultByIndex(0);
ir::OpResult target_op_result = operation->result(0);
(*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
return operation;
}
Expand Down Expand Up @@ -190,7 +190,7 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
dtype = phi::DataType::BOOL;
}
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
ir::Builder builder(ctx, program->block());
paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, data, dtype, phi::CPUPlace());

Expand All @@ -206,7 +206,7 @@ inline ir::Operation* InsertFullArrayOperationForAttributeInput(
phi::IntArray int_array =
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data();

ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
ir::Builder builder(ctx, program->block());
paddle::dialect::FullIntArrayOp full_int_array_op =
builder.Build<paddle::dialect::FullIntArrayOp>(
int_array.GetData(), phi::DataType::INT64, phi::CPUPlace());
Expand Down Expand Up @@ -244,7 +244,7 @@ inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr);
}

return defining_op->GetResultByIndex(0);
return defining_op->result(0);
}

inline std::vector<ir::OpResult> GenerateOperationInput(
Expand Down Expand Up @@ -340,7 +340,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
} else {
auto* combine_op = InsertCombineOperationForTarget(
ctx, param_map, program, legacy_input_vars);
op_inputs.push_back(combine_op->GetResultByIndex(0));
op_inputs.push_back(combine_op->result(0));
}
}

Expand Down Expand Up @@ -472,7 +472,7 @@ inline void RecordOpResultMapping(TranslationContext* param_map,
VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << arg_name << " " << idx;

ir::OpResult value = operation->GetResultByIndex(idx);
ir::OpResult value = operation->result(idx);
bool generated_by_vector = value.type().isa<ir::VectorType>();
(*param_map)[arg_name] = VariableDefiningInfo(
value, generated_by_vector, generated_by_vector ? idx_in_vector : -1);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
ir::Operation* op =
InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]);
program->block()->push_back(op);
param_map[var_name] = VariableDefiningInfo(op->GetResultByIndex(0));
param_map[var_name] = VariableDefiningInfo(op->result(0));
VLOG(10) << "[op translated][get parameter]" << op;

program->SetParameter(var_name, nullptr);
Expand Down
8 changes: 0 additions & 8 deletions paddle/ir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,6 @@ class Builder {
Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {}

static Builder AtBlockBegin(IrContext *context, Block *block) {
return Builder(context, block, block->begin());
}

static Builder AtBlockEnd(IrContext *context, Block *block) {
return Builder(context, block, block->end());
}

IrContext *context() const { return context_; }

Block *block() const { return block_; }
Expand Down
10 changes: 5 additions & 5 deletions paddle/ir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void IrPrinter::PrintOpResult(Operation* op) {
std::vector<OpResult> op_results;
op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx));
op_results.push_back(op->result(idx));
}
PrintInterleave(
op_results.begin(),
Expand Down Expand Up @@ -230,7 +230,7 @@ void IrPrinter::PrintOpOperands(Operation* op) {
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).source());
op_operands.push_back(op->operand(idx).source());
}
PrintInterleave(
op_operands.begin(),
Expand All @@ -245,9 +245,9 @@ void IrPrinter::PrintOperandsType(Operation* op) {
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx);
auto op_operand = op->operand(idx);
if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
op_operand_types.push_back(op->operand(idx).source().type());
} else {
op_operand_types.push_back(Type(nullptr));
}
Expand All @@ -266,7 +266,7 @@ void IrPrinter::PrintOpReturnType(Operation* op) {
std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
auto op_result = op->GetResultByIndex(idx);
auto op_result = op->result(idx);
if (op_result) {
op_result_types.push_back(op_result.type());
} else {
Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes,
num_operands_(num_operands),
num_regions_(num_regions) {}

ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
ir::OpResult Operation::result(uint32_t index) const {
if (index >= num_results_) {
IR_THROW("index exceeds OP output range.");
}
Expand All @@ -200,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
}
}

ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
ir::OpOperand Operation::operand(uint32_t index) const {
if (index >= num_operands_) {
IR_THROW("index exceeds OP input range.");
}
Expand Down
6 changes: 4 additions & 2 deletions paddle/ir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ class alignas(8) Operation final {
void Destroy();

IrContext *ir_context() const;

Dialect *dialect() const;
OpResult GetResultByIndex(uint32_t index) const;

OpOperand GetOperandByIndex(uint32_t index) const;
OpResult result(uint32_t index) const;

OpOperand operand(uint32_t index) const;

void Print(std::ostream &os);

Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/pattern_rewrite/pattern_match.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ RewriterBase::~RewriterBase() = default;
// // assert(op->num_results() == new_values.size() && "incorrect number of
// values to replace operation"); NotifyRootReplaced(op, new_values); bool
// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) {
// // op->GetResultByIndex(0)
// // op->result(0)
// }
// }
// void RewriterBase::ReplaceOpWithIf(Operation* op,
Expand Down Expand Up @@ -138,7 +138,7 @@ void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op,
"replacement op doesn't match results of original op");
// TODO(wilber): Op support results method.
// if (op->num_results() == 1) return ReplaceOp(op,
// new_op->GetResultByIndex(0)); return ReplaceOp(op, new_op->GetResults());
// new_op->result(0)); return ReplaceOp(op, new_op->GetResults());
}

} // namespace ir
17 changes: 6 additions & 11 deletions test/cpp/ir/core/ir_exe_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TEST(program_test, program) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
ir::Builder builder(ctx, program.block());
ir::Block* block = program.block();

// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
Expand All @@ -68,9 +68,7 @@ TEST(program_test, program) {
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
EXPECT_EQ(uniform1->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);

Expand All @@ -82,18 +80,15 @@ TEST(program_test, program) {
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
EXPECT_EQ(uniform2->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);

// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->GetResultByIndex(0), uniform2->GetResultByIndex(0));
EXPECT_EQ(
add->GetResultByIndex(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
uniform1->result(0), uniform2->result(0));
EXPECT_EQ(add->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);

// Execute program
Expand Down
52 changes: 20 additions & 32 deletions test/cpp/ir/core/ir_program_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,10 @@ TEST(program_test, program) {

EXPECT_EQ(&program, op1->GetParentProgram());

EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id());
EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
using Interface = paddle::dialect::ParameterConvertInterface;
Interface *a_interface = op1->GetResultByIndex(0)
.type()
.dialect()
.GetRegisteredInterface<Interface>();
Interface *a_interface =
op1->result(0).type().dialect().GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> a_var =
a_interface->ParameterToVariable(program.GetParameter("a"));
const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
Expand All @@ -134,12 +131,9 @@ TEST(program_test, program) {
ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);

EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id());
Interface *b_interface = op2->GetResultByIndex(0)
.type()
.dialect()
.GetRegisteredInterface<Interface>();
EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
Interface *b_interface =
op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> b_var =
b_interface->ParameterToVariable(program.GetParameter("b"));
const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
Expand All @@ -158,11 +152,10 @@ TEST(program_test, program) {
builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
op3_attribute,
{dense_tensor_dtype},
op3_info);
ir::Operation *op3 = ir::Operation::Create({op1->result(0), op2->result(0)},
op3_attribute,
{dense_tensor_dtype},
op3_info);
block->push_back(op3);

phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
Expand All @@ -186,7 +179,7 @@ TEST(program_test, program) {

// (7) Def AbsOp(b)
ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs");
std::vector<ir::OpResult> operands = {op1->GetResultByIndex(0)};
std::vector<ir::OpResult> operands = {op1->result(0)};
std::unordered_map<std::string, ir::Attribute> abs_op_attribute;
std::vector<ir::Type> output_types = {dense_tensor_dtype};
ir::OperationArgument abs_argument(abs_info);
Expand All @@ -205,15 +198,14 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}};

ir::OperationArgument op4_argument(
{op3->GetResultByIndex(0)}, {}, {}, op4_info);
ir::OperationArgument op4_argument({op3->result(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument));
block->push_back(op4);

EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(),
EXPECT_EQ(op4->operand(0).source().type().dialect().id(),
paddle_dialect->id());
Interface *c_interface = op4->GetOperandByIndex(0)
Interface *c_interface = op4->operand(0)
.source()
.type()
.dialect()
Expand Down Expand Up @@ -274,21 +266,17 @@ TEST(program_test, slice_combine_test) {
ir::Type output_type =
ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype}));
ir::Operation *combine_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{},
{output_type},
combine_op_info);
{op1->result(0), op2->result(0)}, {}, {output_type}, combine_op_info);
program.block()->push_back(combine_op);

// (7) Def slice_op = SliceOp(combine_op, 0)
std::string slice_op_name = std::string(ir::SliceOp::name());
ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name);
ir::Attribute index_attr = ir::Int32Attribute::get(ctx, 0);
ir::Operation *slice_op =
ir::Operation::Create({combine_op->GetResultByIndex(0)},
{{"index", index_attr}},
{fp32_dtype},
slice_op_info);
ir::Operation *slice_op = ir::Operation::Create({combine_op->result(0)},
{{"index", index_attr}},
{fp32_dtype},
slice_op_info);
program.block()->push_back(slice_op);

// (8) Traverse Program
Expand All @@ -303,7 +291,7 @@ TEST(program_test, builder) {

paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
ir::Type full_op_output = full_op->GetResultByIndex(0).type();
ir::Type full_op_output = full_op->result(0).type();
EXPECT_EQ(program.block()->size(), 1u);
EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op->num_operands(), 0u);
Expand Down
Loading

0 comments on commit eac99c5

Please sign in to comment.