diff --git a/paddle/ir/core/block.cc b/paddle/ir/core/block.cc index 5012c7a2eb07f4..b74a960f7fc453 100644 --- a/paddle/ir/core/block.cc +++ b/paddle/ir/core/block.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/ir/core/block.h" +#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/region.h" @@ -32,6 +33,12 @@ Block::iterator Block::insert(const_iterator iterator, Operation *op) { return iter; } +Block::iterator Block::erase(const_iterator position) { + IR_ENFORCE((*position)->GetParent() == this, "iterator not own this block."); + (*position)->Destroy(); + return ops_.erase(position); +} + void Block::clear() { while (!empty()) { ops_.back()->Destroy(); diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index dd4663d0744ebe..8777aaf40db7dc 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -50,6 +50,7 @@ class IR_API Block { void push_back(Operation *op); void push_front(Operation *op); iterator insert(const_iterator iterator, Operation *op); + iterator erase(const_iterator position); void clear(); operator Region::iterator() { return position_; } diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index df6f22b0b4406a..ed49b347780810 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -33,8 +33,8 @@ Program *ModuleOp::program() { Block *ModuleOp::block() { assert(operation() != nullptr); assert(operation()->num_regions() == 1); - assert(operation()->GetRegion(0).size() == 1); - return operation()->GetRegion(0).front(); + assert(operation()->region(0).size() == 1); + return operation()->region(0).front(); } ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) { @@ -71,6 +71,15 @@ void ModuleOp::Verify(const std::vector &inputs, const char *GetParameterOp::attributes_name[attributes_num] = { "parameter_name"}; +void GetParameterOp::Build(Builder &builder, + OperationArgument &argument, + const std::string &name, + Type type) { + argument.attributes[attributes_name[0]] = + ir::StrAttribute::get(builder.ir_context(), name); + argument.output_types.emplace_back(type); +} + void GetParameterOp::Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) { @@ -90,6 +99,14 @@ void GetParameterOp::Verify(const std::vector &inputs, const char *SetParameterOp::attributes_name[attributes_num] = { "parameter_name"}; +void SetParameterOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + OpResult parameter, + const std::string &name) { + argument.AddOperand(parameter); + argument.AddAttribute(attributes_name[0], + ir::StrAttribute::get(builder.ir_context(), name)); +} void SetParameterOp::Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) { @@ -106,6 +123,18 @@ void SetParameterOp::Verify(const std::vector &inputs, IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0."); } +void CombineOp::Build(Builder &builder, + OperationArgument &argument, + const std::vector &inputs) { + argument.inputs = inputs; + std::vector inputs_type(inputs.size()); + for (size_t idx = 0; idx < inputs.size(); ++idx) { + inputs_type[idx] = inputs[idx].type(); + } + argument.output_types.emplace_back( + ir::VectorType::get(builder.ir_context(), inputs_type)); +} + void CombineOp::Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) { diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index 13996397b37fcc..56cfafd35ffd68 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -54,8 +54,12 @@ class IR_API GetParameterOp : public ir::Op { static const char *name() { return "builtin.get_parameter"; } static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - const std::vector &outputs, + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::string &name, + Type type); + static void Verify(const std::vector &inputs, + const std::vector &outputs, const ir::AttributeMap &attributes); }; @@ -69,6 +73,10 @@ class IR_API SetParameterOp : public ir::Op { static const char *name() { return "builtin.set_parameter"; } static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + OpResult parameter, + const std::string &name); static void Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); @@ -87,6 +95,10 @@ class IR_API CombineOp : public ir::Op { static constexpr const char **attributes_name = nullptr; + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::vector &inputs); + static void Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 2e611a5a76305b..c87bba1c8b3562 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -104,7 +104,7 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) { void IrPrinter::PrintProgram(Program* program) { auto top_level_op = program->module_op(); for (size_t i = 0; i < top_level_op->num_regions(); ++i) { - auto& region = top_level_op->GetRegion(i); + auto& region = top_level_op->region(i); for (auto it = region.begin(); it != region.end(); ++it) { auto* block = *it; os << "{\n"; @@ -153,7 +153,7 @@ void IrPrinter::PrintFullOperation(Operation* op) { os << newline; } for (size_t i = 0; i < op->num_regions(); ++i) { - auto& region = op->GetRegion(i); + auto& region = op->region(i); PrintRegion(region); } } diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 28de2403da8aad..991f8dbe2e107d 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -33,7 +33,7 @@ Operation *Operation::Create(OperationArgument &&argument) { argument.regions.size()); for (size_t index = 0; index < argument.regions.size(); ++index) { - op->GetRegion(index).TakeBody(std::move(*argument.regions[index])); + op->region(index).TakeBody(std::move(*argument.regions[index])); } return op; } @@ -103,17 +103,35 @@ Operation *Operation::Create(const std::vector &inputs, return op; } -// Call destructors for OpResults, Operation, and OpOperands in sequence, and -// finally free memory. +// Call destructors for Region , OpResults, Operation, and OpOperands in +// sequence, and finally free memory. void Operation::Destroy() { - // Deconstruct Regions. + // 1. Deconstruct Regions. if (num_regions_ > 0) { for (size_t idx = 0; idx < num_regions_; idx++) { regions_[idx].~Region(); } } - // 1. Get aligned_ptr by result_num. + // 2. Deconstruct Result. + for (size_t idx = 0; idx < num_results_; ++idx) { + detail::OpResultImpl *impl = result(idx).impl(); + IR_ENFORCE(impl->use_empty(), "operation destroyed but still has uses."); + if (detail::OpOutlineResultImpl::classof(*impl)) { + static_cast(impl)->~OpOutlineResultImpl(); + } else { + static_cast(impl)->~OpInlineResultImpl(); + } + } + + // 3. Deconstruct Operation. + this->~Operation(); + + // 4. Deconstruct OpOperand. + for (size_t idx = 0; idx < num_operands_; idx++) { + operand(idx).impl()->~OpOperandImpl(); + } + // 5. Free memory. uint32_t max_inline_result_num = detail::OpResultImpl::GetMaxInlineResultIndex() + 1; size_t result_mem_size = @@ -122,46 +140,11 @@ void Operation::Destroy() { (num_results_ - max_inline_result_num) + sizeof(detail::OpInlineResultImpl) * max_inline_result_num : sizeof(detail::OpInlineResultImpl) * num_results_; - char *aligned_ptr = reinterpret_cast(this) - result_mem_size; - // 2.1. Deconstruct OpResult. - char *base_ptr = aligned_ptr; - for (size_t idx = num_results_; idx > 0; idx--) { - // release the uses of this result - detail::OpOperandImpl *first_use = - reinterpret_cast(base_ptr)->first_use(); - while (first_use != nullptr) { - first_use->RemoveFromUdChain(); - first_use = - reinterpret_cast(base_ptr)->first_use(); - } - // destory the result - if (idx > max_inline_result_num) { - reinterpret_cast(base_ptr) - ->~OpOutlineResultImpl(); - base_ptr += sizeof(detail::OpOutlineResultImpl); - } else { - reinterpret_cast(base_ptr) - ->~OpInlineResultImpl(); - base_ptr += sizeof(detail::OpInlineResultImpl); - } - } - // 2.2. Deconstruct Operation. - if (reinterpret_cast(base_ptr) != - reinterpret_cast(this)) { - IR_THROW("Operation address error"); - } - reinterpret_cast(base_ptr)->~Operation(); - base_ptr += sizeof(Operation); - // 2.3. Deconstruct OpOperand. - for (size_t idx = 0; idx < num_operands_; idx++) { - reinterpret_cast(base_ptr)->~OpOperandImpl(); - base_ptr += sizeof(detail::OpOperandImpl); - } - // 3. Free memory. - VLOG(4) << "Destroy an Operation: {ptr = " - << reinterpret_cast(aligned_ptr) + void *aligned_ptr = reinterpret_cast(this) - result_mem_size; + + VLOG(4) << "Destroy an Operation: {ptr = " << aligned_ptr << ", size = " << result_mem_size << "}"; - aligned_free(reinterpret_cast(aligned_ptr)); + aligned_free(aligned_ptr); } IrContext *Operation::ir_context() const { return info_.ir_context(); } @@ -231,7 +214,7 @@ Program *Operation::GetParentProgram() { return module_op ? module_op.program() : nullptr; } -Region &Operation::GetRegion(unsigned index) { +Region &Operation::region(unsigned index) { assert(index < num_regions_ && "invalid region index"); return regions_[index]; } diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 30f25d83c4a7cd..6a6d9dc19de5bb 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -54,6 +54,9 @@ class IR_API alignas(8) Operation final { OpOperand operand(uint32_t index) const; + /// Returns the region held by this operation at position 'index'. + Region ®ion(unsigned index); + void Print(std::ostream &os); const AttributeMap &attributes() const { return attributes_; } @@ -95,11 +98,10 @@ class IR_API alignas(8) Operation final { Program *GetParentProgram(); - /// Returns the region held by this operation at position 'index'. - Region &GetRegion(unsigned index); - operator Block::iterator() { return position_; } + operator Block::const_iterator() const { return position_; } + private: Operation(const AttributeMap &attribute, ir::OpInfo op_info, diff --git a/paddle/ir/core/operation_utils.h b/paddle/ir/core/operation_utils.h index 7c012bcd0d5bc9..cbf19a4bb74c76 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/ir/core/operation_utils.h @@ -51,9 +51,15 @@ struct OperationArgument { info(info), regions(std::move(regions)) {} + /// Add Operand. + void AddOperand(OpResult operand) { inputs.emplace_back(operand); } + template void AddOperands(InputIt first, InputIt last); + /// Add Output. + void AddOutput(Type type) { output_types.emplace_back(type); } + template void AddTypes(InputIt first, InputIt last); diff --git a/paddle/ir/core/region.cc b/paddle/ir/core/region.cc index cc94d6936901f5..3d2cbe5be64913 100644 --- a/paddle/ir/core/region.cc +++ b/paddle/ir/core/region.cc @@ -14,6 +14,7 @@ #include "paddle/ir/core/region.h" #include "paddle/ir/core/block.h" +#include "paddle/ir/core/enforce.h" namespace ir { Region::~Region() { clear(); } @@ -29,6 +30,12 @@ Region::iterator Region::insert(const_iterator position, Block *block) { block->SetParent(this, iter); return iter; } + +Region::iterator Region::erase(const_iterator position) { + IR_ENFORCE((*position)->GetParent() == this, "iterator not own this region."); + delete *position; + return blocks_.erase(position); +} void Region::TakeBody(Region &&other) { clear(); blocks_.swap(other.blocks_); diff --git a/paddle/ir/core/region.h b/paddle/ir/core/region.h index b84ea97bfd9ed9..5d3c78e59a6001 100644 --- a/paddle/ir/core/region.h +++ b/paddle/ir/core/region.h @@ -48,6 +48,7 @@ class IR_API Region { void emplace_back(); void push_front(Block *block); iterator insert(const_iterator position, Block *block); + iterator erase(const_iterator position); void clear(); void TakeBody(Region &&other); diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 118195718ff4f3..4e7afd2d835edb 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -34,17 +34,22 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { } OpOperand::operator bool() const { return impl_ && impl_->source(); } -OpOperand OpOperand::next_use() const { return impl_->next_use(); } +OpOperand OpOperand::next_use() const { return impl()->next_use(); } -Value OpOperand::source() const { return impl_->source(); } +Value OpOperand::source() const { return impl()->source(); } -void OpOperand::set_source(Value value) { - IR_ENFORCE(impl_, "Can't set source for a null value."); - impl_->set_source(value); -} +Type OpOperand::type() const { return source().type(); } + +void OpOperand::set_source(Value value) { impl()->set_source(value); } + +Operation *OpOperand::owner() const { return impl()->owner(); } -Operation *OpOperand::owner() const { return impl_->owner(); } +void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); } +detail::OpOperandImpl *OpOperand::impl() const { + IR_ENFORCE(impl_, "Can't use impl() interface while operand is null."); + return impl_; +} // Value Value::Value(const detail::ValueImpl *impl) : impl_(const_cast(impl)) {} @@ -84,13 +89,18 @@ void Value::ReplaceUsesWithIf( Value new_value, const std::function &should_replace) const { for (auto it = begin(); it != end();) { - auto cur = it++; - if (should_replace(*cur)) { - cur->set_source(new_value); + if (should_replace(*it)) { + (it++)->set_source(new_value); } } } +void Value::ReplaceAllUsesWith(Value new_value) const { + for (auto it = begin(); it != end();) { + (it++)->set_source(new_value); + } +} + detail::ValueImpl *Value::impl() const { IR_ENFORCE(impl_, "Can't use impl() interface while value is null."); return impl_; @@ -106,6 +116,7 @@ Operation *OpResult::owner() const { return impl()->owner(); } uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); } detail::OpResultImpl *OpResult::impl() const { + IR_ENFORCE(impl_, "Can't use impl() interface while value is null."); return reinterpret_cast(impl_); } diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index 1499effbe8da37..7fa336fed4a4b4 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -55,11 +55,21 @@ class IR_API OpOperand { Value source() const; + Type type() const; + void set_source(Value value); Operation *owner() const; + void RemoveFromUdChain(); + + friend Operation; + private: + // The interface shoule ensure impl_ isn't nullptr. + // if the user can accept impl_ is nullptr, shoule use impl_ member directly. + detail::OpOperandImpl *impl() const; + detail::OpOperandImpl *impl_{nullptr}; }; @@ -155,6 +165,7 @@ class IR_API Value { void ReplaceUsesWithIf( Value new_value, const std::function &should_replace) const; + void ReplaceAllUsesWith(Value new_value) const; // The interface shoule ensure impl_ isn't nullptr. // if the user can accept impl_ is nullptr, shoule use impl_ member directly. diff --git a/paddle/ir/pass/pass.cc b/paddle/ir/pass/pass.cc index a3ccd178db1e16..0186ea892f0d6a 100644 --- a/paddle/ir/pass/pass.cc +++ b/paddle/ir/pass/pass.cc @@ -44,7 +44,7 @@ void detail::PassAdaptor::RunImpl(Operation* op, auto last_am = analysis_manager(); for (size_t i = 0; i < op->num_regions(); ++i) { - auto& region = op->GetRegion(i); + auto& region = op->region(i); for (auto it = region.begin(); it != region.end(); ++it) { auto* block = *it; for (auto it = block->begin(); it != block->end(); ++it) { diff --git a/test/cpp/ir/core/CMakeLists.txt b/test/cpp/ir/core/CMakeLists.txt index b3ab94ebe807f4..4987348bf82afe 100644 --- a/test/cpp/ir/core/CMakeLists.txt +++ b/test/cpp/ir/core/CMakeLists.txt @@ -2,6 +2,7 @@ cc_test_old(type_test SRCS type_test.cc DEPS ir gtest) cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS ir gtest) cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS ir gtest) cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS ir gtest) +cc_test_old(ir_region_test SRCS ir_region_test.cc DEPS ir gtest) cc_test_old( ir_program_test SRCS diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 57828c86c2122b..a55f3eeb347340 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -48,7 +48,21 @@ class AddOp : public ir::Op { throw("The size of outputs must be equal to 1."); } } + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult l_operand, + ir::OpResult r_operand, + ir::Type sum_type); }; +void AddOp::Build(ir::Builder &, + ir::OperationArgument &argument, + ir::OpResult l_operand, + ir::OpResult r_operand, + ir::Type sum_type) { + argument.AddOperand(l_operand); + argument.AddOperand(r_operand); + argument.AddOutput(sum_type); +} IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) IR_DEFINE_EXPLICIT_TYPE_ID(AddOp) @@ -90,22 +104,10 @@ TEST(program_test, program) { EXPECT_EQ(program.parameters_num() == 2, true); // (4) Def a = GetParameterOp("a"), and create DenseTensor for a. - std::string op1_name = ir::GetParameterOp::name(); - ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); - std::unordered_map op1_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; - ir::Operation *op1 = - ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info); - - ir::Block *block = program.block(); - block->push_back(op1); - - EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent()); - - EXPECT_EQ(program.module_op(), block->GetParentOp()); + ir::Builder builder(ctx, program.block()); + auto op1 = builder.Build("a", dense_tensor_dtype); EXPECT_EQ(&program, op1->GetParentProgram()); - EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id()); using Interface = paddle::dialect::ParameterConvertInterface; Interface *a_interface = @@ -124,14 +126,7 @@ TEST(program_test, program) { } // (5) Def b = GetParameterOp("b"), and create DenseTensor for b. - std::string op2_name = - builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name()); - ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); - std::unordered_map op2_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; - ir::Operation *op2 = - ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info); - block->push_back(op2); + auto op2 = builder.Build("b", dense_tensor_dtype); EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id()); Interface *b_interface = @@ -150,16 +145,8 @@ TEST(program_test, program) { } // (6) Def c = AddOp(a, b), execute this op. - std::string op3_name = - builtin_dialect->name() + "." + std::string(AddOp::name()); - ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name); - std::unordered_map op3_attribute; - ir::Operation *op3 = ir::Operation::Create({op1->result(0), op2->result(0)}, - op3_attribute, - {dense_tensor_dtype}, - op3_info); - block->push_back(op3); - + auto op3 = + builder.Build(op1->result(0), op2->result(0), dense_tensor_dtype); phi::CPUContext *dev_ctx = static_cast( paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::CPUPlace())); @@ -180,38 +167,17 @@ TEST(program_test, program) { } // (7) Def AbsOp(b) - ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); - std::vector operands = {op1->result(0)}; - std::unordered_map abs_op_attribute; - std::vector output_types = {dense_tensor_dtype}; - ir::OperationArgument abs_argument(abs_info); - abs_argument.AddOperands(operands.begin(), operands.end()); - abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); - abs_argument.AddTypes(output_types.begin(), output_types.end()); - ir::Operation *abs_op = ir::Operation::Create(std::move(abs_argument)); + auto abs_op = builder.Build(op1->result(0)); paddle::dialect::OpYamlInfoInterface interface = abs_op->dyn_cast(); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); // (8) Def SetParameterOp(c, "c") - std::string op4_name = - builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name()); - ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name); - std::unordered_map op4_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "c")}}; - - ir::OperationArgument op4_argument({op3->result(0)}, {}, {}, op4_info); - op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end()); - ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument)); - block->push_back(op4); + auto op4 = builder.Build(op3->result(0), "c"); - EXPECT_EQ(op4->operand(0).source().type().dialect().id(), - paddle_dialect->id()); - Interface *c_interface = op4->operand(0) - .source() - .type() - .dialect() - .GetRegisteredInterface(); + EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id()); + Interface *c_interface = + op4->operand(0).type().dialect().GetRegisteredInterface(); // ir::Parameter *parameter_c = // c_interface->VariableToParameter(variable_c.get()); std::unique_ptr parameter_c = @@ -224,7 +190,7 @@ TEST(program_test, program) { program.SetParameter("c", std::move(parameter_c)); // (8) Traverse Program - EXPECT_EQ(program.block()->size() == 4, true); + EXPECT_EQ(program.block()->size() == 5, true); EXPECT_EQ(program.parameters_num() == 3, true); program.Print(std::cout); diff --git a/test/cpp/ir/core/ir_region_test.cc b/test/cpp/ir/core/ir_region_test.cc new file mode 100644 index 00000000000000..33c6144fe77daa --- /dev/null +++ b/test/cpp/ir/core/ir_region_test.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/ir/core/block.h" +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/builtin_type.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/core/utils.h" + +TEST(region, erase_op_test) { + // (1) Init environment. + ir::IrContext* ctx = ir::IrContext::Instance(); + + // (2) Create an empty program object + ir::Program program(ctx); + ir::Builder builder = ir::Builder(ctx, program.block()); + + // (3) Def a = ConstantOp("2.0"); b = ConstantOp("2.0"); + ir::FloatAttribute fp_attr = ir::FloatAttribute::get(ctx, 2.0f); + ir::Float32Type fp32_type = ir::Float32Type::get(ctx); + ir::OpResult a = builder.Build(fp_attr, fp32_type)->result(0); + ir::OpResult b = builder.Build(fp_attr, fp32_type)->result(0); + + // (6) Def c = CombineOp(a, b) + builder.Build(std::vector{a, b}); + + // Test ir::Block::erase + ir::Block* block = program.block(); + EXPECT_EQ(block->size(), 3u); + block->erase(*(block->back())); + EXPECT_EQ(block->size(), 2u); + + // Test ir::Region::erase + ir::Region& region = program.module_op()->region(0); + region.push_back(new ir::Block()); + EXPECT_EQ(region.size(), 2u); + region.erase(region.begin()); + EXPECT_EQ(region.size(), 1u); +} diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index 8ff759ef03850b..ae1edfe685dd2b 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -104,10 +104,14 @@ TEST(value_test, value_test) { // Test 4: Value Replace Use // a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c); // - c.ReplaceUsesWithIf(a, [](ir::OpOperand) { return true; }); - EXPECT_EQ(op4->operand(1).source(), a); + c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; }); + EXPECT_EQ(op4->operand(1).source(), b); EXPECT_TRUE(c.use_empty()); + b.ReplaceAllUsesWith(a); + EXPECT_EQ(op4->operand(1).source(), a); + EXPECT_TRUE(b.use_empty()); + // destroy VLOG(0) << op1->result(0).PrintUdChain() << std::endl; op4->Destroy(); diff --git a/test/cpp/ir/pass/pass_manager_test.cc b/test/cpp/ir/pass/pass_manager_test.cc index fa12303d69d866..0b3aa35829fab7 100644 --- a/test/cpp/ir/pass/pass_manager_test.cc +++ b/test/cpp/ir/pass/pass_manager_test.cc @@ -45,7 +45,21 @@ class AddOp : public ir::Op { throw("The size of outputs must be equal to 1."); } } + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult l_operand, + ir::OpResult r_operand, + ir::Type sum_type); }; +void AddOp::Build(ir::Builder &, + ir::OperationArgument &argument, + ir::OpResult l_operand, + ir::OpResult r_operand, + ir::Type sum_type) { + argument.AddOperand(l_operand); + argument.AddOperand(r_operand); + argument.AddOutput(sum_type); +} IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) IR_DEFINE_EXPLICIT_TYPE_ID(AddOp) @@ -79,10 +93,9 @@ TEST(pass_manager_test, pass_manager) { // (3) Create a float32 DenseTensor Parameter and save into Program ir::Type fp32_dtype = ir::Float32Type::get(ctx); - paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2}; - paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout = - paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW; - paddle::dialect::DenseTensorTypeStorage::LoD lod = {{0, 1, 2}}; + phi::DDim dims = {2, 2}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; size_t offset = 0; ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get( ctx, fp32_dtype, dims, data_layout, lod, offset); @@ -104,22 +117,10 @@ TEST(pass_manager_test, pass_manager) { EXPECT_EQ(program.parameters_num() == 2, true); // (4) Def a = GetParameterOp("a"), and create DenseTensor for a. - std::string op1_name = ir::GetParameterOp::name(); - ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); - std::unordered_map op1_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; - ir::Operation *op1 = - ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info); - - ir::Block *block = program.block(); - block->push_back(op1); - - EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent()); - - EXPECT_EQ(program.module_op(), block->GetParentOp()); + ir::Builder builder(ctx, program.block()); + auto op1 = builder.Build("a", dense_tensor_dtype); EXPECT_EQ(&program, op1->GetParentProgram()); - EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id()); using Interface = paddle::dialect::ParameterConvertInterface; Interface *a_interface = @@ -138,15 +139,7 @@ TEST(pass_manager_test, pass_manager) { } // (5) Def b = GetParameterOp("b"), and create DenseTensor for b. - std::string op2_name = - builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name()); - ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); - std::unordered_map op2_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; - ir::Operation *op2 = - ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info); - block->push_back(op2); - + auto op2 = builder.Build("b", dense_tensor_dtype); EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id()); Interface *b_interface = op2->result(0).type().dialect().GetRegisteredInterface(); @@ -164,16 +157,8 @@ TEST(pass_manager_test, pass_manager) { } // (6) Def c = AddOp(a, b), execute this op. - std::string op3_name = - builtin_dialect->name() + "." + std::string(AddOp::name()); - ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name); - std::unordered_map op3_attribute; - ir::Operation *op3 = ir::Operation::Create({op1->result(0), op2->result(0)}, - op3_attribute, - {dense_tensor_dtype}, - op3_info); - block->push_back(op3); - + auto op3 = + builder.Build(op1->result(0), op2->result(0), dense_tensor_dtype); phi::CPUContext *dev_ctx = static_cast( paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::CPUPlace())); @@ -193,39 +178,12 @@ TEST(pass_manager_test, pass_manager) { EXPECT_EQ(*(dst_tensor->data() + i), data_a[i] + data_b[i]); } - // (7) Def AbsOp(b) - ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); - std::vector operands = {op1->result(0)}; - std::unordered_map abs_op_attribute; - std::vector output_types = {dense_tensor_dtype}; - ir::OperationArgument abs_argument(abs_info); - abs_argument.AddOperands(operands.begin(), operands.end()); - abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); - abs_argument.AddTypes(output_types.begin(), output_types.end()); - ir::Operation *abs_op = ir::Operation::Create(std::move(abs_argument)); - paddle::dialect::OpYamlInfoInterface interface = - abs_op->dyn_cast(); - EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); - - // (8) Def SetParameterOp(c, "c") - std::string op4_name = - builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name()); - ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name); - std::unordered_map op4_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "c")}}; - - ir::OperationArgument op4_argument({op3->result(0)}, {}, {}, op4_info); - op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end()); - ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument)); - block->push_back(op4); - + // (7) Def SetParameterOp(c, "c") + auto op4 = builder.Build(op3->result(0), "c"); EXPECT_EQ(op4->operand(0).source().type().dialect().id(), paddle_dialect->id()); - Interface *c_interface = op4->operand(0) - .source() - .type() - .dialect() - .GetRegisteredInterface(); + Interface *c_interface = + op4->operand(0).type().dialect().GetRegisteredInterface(); // ir::Parameter *parameter_c = // c_interface->VariableToParameter(variable_c.get()); std::unique_ptr parameter_c =