diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 535260c2ee96b..79fe860d2ab8f 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -67,7 +67,7 @@ struct Group { std::vector<::pir::Operation*> new_ops; // Mapper from original to new ops. std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper; - ::pir::CloneOptions clone_options(false, true); + auto clone_options = ::pir::CloneOptions(false, true, false); for (auto* op : ops) { VLOG(4) << "clone op :" << op->name(); auto* new_op = op->Clone(ir_mapping, clone_options); diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 300ff3c7f7ccc..1eedb38b1c777 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1024,7 +1024,7 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block, using OpResultMap = std::pair, std::vector>; std::pair, OpResultMap> CloneProgram( - Program &program) { // NOLINT + const Program &program) { // Limitation of this function: // 1. don't support Parameters. pir::IrMapping mapper; @@ -1111,7 +1111,7 @@ SplitedResult SplitForwardBackward( // forward program construct. VLOG(4) << "start create forward program."; pir::IrMapping forward_mapper; - auto clone_options = pir::CloneOptions(true, true); + auto clone_options = pir::CloneOptions::All(); range_block_do( program.block(), forward_range, diff --git a/paddle/pir/core/ir_mapping.h b/paddle/pir/core/ir_mapping.h index 2c05af2e8975c..973179693888b 100644 --- a/paddle/pir/core/ir_mapping.h +++ b/paddle/pir/core/ir_mapping.h @@ -33,13 +33,22 @@ struct ExactlyOneIrType { typename ExactlyOneIrType::type>; }; } // namespace detail + class IrMapping { public: template - using IrType = - typename detail::ExactlyOneIrType::type; + using remove_lowlevel_const_t = std::conditional_t< + std::is_pointer::value, + std::add_pointer_t>>, + T>; + template + using IrType = typename detail::ExactlyOneIrType, + Value, + Block*, + Operation*>::type; + template - std::unordered_map& GetMutableMap() { + auto& GetMutableMap() { if constexpr (std::is_same::value) { return value_map_; } else if constexpr (std::is_same::value) { @@ -50,8 +59,9 @@ class IrMapping { IR_THROW("Not support type in IRMapping."); } } + template - const std::unordered_map& GetMap() const { + const auto& GetMap() const { if constexpr (std::is_same::value) { return value_map_; } else if constexpr (std::is_same::value) { @@ -62,6 +72,7 @@ class IrMapping { IR_THROW("Not support type in IRMapping."); } } + template void Add(T from, S to) { if (!from) return; @@ -69,8 +80,8 @@ class IrMapping { } template - T Lookup(T from) const { - if (!from) return static_cast(nullptr); + IrType Lookup(T from) const { + if (!from) return static_cast>(nullptr); IR_ENFORCE(GetMap>().count(from) > 0, "Not found key in IRMapping."); return GetMap>().at(from); @@ -89,8 +100,8 @@ class IrMapping { private: std::unordered_map value_map_; - std::unordered_map block_map_; - std::unordered_map operation_map_; + std::unordered_map block_map_; + std::unordered_map operation_map_; }; } // namespace pir diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index e5b978e58bf12..b6cc4cf789ddb 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -137,10 +137,7 @@ Operation *Operation::Create(const std::vector &inputs, return op; } -Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) { - IR_ENFORCE(num_successors_ == 0, - "Operation::Clone is not unimplemented for multiple successors."); - +Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) const { auto inputs = operands_source(); if (options.IsCloneOperands()) { // replace value by IRMapping inplacely. @@ -153,7 +150,15 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) { for (auto &result : results()) { output_types.push_back(result.type()); } - auto *new_op = Create(inputs, attributes_, output_types, info_, num_regions_); + + std::vector successors = {}; + if (options.IsCloneSuccessors()) { + for (uint32_t i = 0; i < num_successors_; ++i) { + successors.push_back(ir_mapping.Lookup(successor(i))); + } + } + auto *new_op = Create( + inputs, attributes_, output_types, info_, num_regions_, successors); ir_mapping.Add(this, new_op); // record outputs mapping info @@ -246,7 +251,7 @@ Operation::Operation(const AttributeMap &attributes, /// /// \brief op ouput related public interfaces implementation /// -std::vector Operation::results() { +std::vector Operation::results() const { std::vector res; for (uint32_t i = 0; i < num_results(); ++i) { res.push_back(result(i)); diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index 0dafcdd7a5b2b..610cf70948939 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -39,16 +39,28 @@ class OpOperendImpl; class CloneOptions { public: - CloneOptions() : clone_regions_{false}, clone_operands_{false} {} - CloneOptions(bool clone_regions, bool clone_operands) - : clone_regions_(clone_regions), clone_operands_(clone_operands) {} + CloneOptions() + : clone_regions_{false}, + clone_operands_{false}, + clone_successors_{false} {} + CloneOptions(bool clone_regions, bool clone_operands, bool clone_successors) + : clone_regions_(clone_regions), + clone_operands_(clone_operands), + clone_successors_(clone_successors) {} bool IsCloneRegions() const { return clone_regions_; } bool IsCloneOperands() const { return clone_operands_; } + bool IsCloneSuccessors() const { return clone_successors_; } + + static CloneOptions &All() { + static CloneOptions all{true, true, true}; + return all; + } private: bool clone_regions_{true}; bool clone_operands_{true}; + bool clone_successors_{true}; }; class IR_API alignas(8) Operation final @@ -72,7 +84,7 @@ class IR_API alignas(8) Operation final /// \brief Deep copy all information and create a new operation. /// Operation *Clone(IrMapping &ir_mapping, - CloneOptions options = CloneOptions()); + CloneOptions options = CloneOptions()) const; /// /// \brief Destroy the operation objects and free memory by create(). /// @@ -114,7 +126,7 @@ class IR_API alignas(8) Operation final T result_type(uint32_t index) const { return result(index).type().dyn_cast(); } - std::vector results(); + std::vector results() const; /// /// \brief op input related public interfaces diff --git a/paddle/pir/core/program.cc b/paddle/pir/core/program.cc index 7a5dc0a6e987b..e79b2758554c2 100644 --- a/paddle/pir/core/program.cc +++ b/paddle/pir/core/program.cc @@ -27,11 +27,11 @@ Program::~Program() { } } -std::shared_ptr Program::Clone(IrMapping& ir_mapping) { +std::shared_ptr Program::Clone(IrMapping& ir_mapping) const { pir::IrContext* ctx = pir::IrContext::Instance(); auto new_program = std::make_shared(ctx); - auto clone_options = CloneOptions(true, true); - for (auto& op : *block()) { + auto clone_options = CloneOptions::All(); + for (const auto& op : *block()) { auto* new_op = op.Clone(ir_mapping, clone_options); new_program->block()->push_back(new_op); } diff --git a/paddle/pir/core/program.h b/paddle/pir/core/program.h index ac65e57e2ca08..dabd871e519a5 100644 --- a/paddle/pir/core/program.h +++ b/paddle/pir/core/program.h @@ -55,7 +55,7 @@ class IR_API Program { static std::unique_ptr Parse(std::istream& is, IrContext* ctx); - std::shared_ptr Clone(IrMapping& ir_mapping); // NOLINT + std::shared_ptr Clone(IrMapping& ir_mapping) const; // NOLINT Block* block() { return &module_.block(); } const Block* block() const { return &module_op().block(); } diff --git a/paddle/pir/core/region.cc b/paddle/pir/core/region.cc index 9d72ddca0b279..8adce14977abb 100644 --- a/paddle/pir/core/region.cc +++ b/paddle/pir/core/region.cc @@ -42,23 +42,30 @@ Region::Iterator Region::erase(ConstIterator position) { return blocks_.erase(position); } -void Region::CloneInto(Region &other, IrMapping &ir_mapping) { +void Region::CloneInto(Region &other, IrMapping &ir_mapping) const { if (empty()) { return; } other.clear(); - auto clone_options = CloneOptions(false, false); // clone blocks, block arguments and sub operations - for (auto &block : *this) { + for (const auto &block : *this) { auto new_block = new Block; ir_mapping.Add(&block, new_block); - for (auto &arg : block.args()) { + for (const auto &arg : block.args()) { ir_mapping.Add(arg, new_block->AddArgument(arg.type())); } other.push_back(new_block); - // clone sub operations, but not map operands nor clone regions - for (auto op_iter = block.begin(); op_iter != block.end(); ++op_iter) { - new_block->push_back(op_iter->Clone(ir_mapping, clone_options)); + } + // clone sub operations, but not map operands nor clone regions + { + auto clone_options = CloneOptions(false, false, true); + auto iter = begin(); + auto new_iter = other.begin(); + for (; iter != end(); ++iter, ++new_iter) { + const Block &block = *iter; + Block &new_block = *new_iter; + for (const auto &op : block) + new_block.push_back(op.Clone(ir_mapping, clone_options)); } } // after all operation results are mapped, map operands and clone regions. @@ -69,7 +76,7 @@ void Region::CloneInto(Region &other, IrMapping &ir_mapping) { auto op_iter = iter->begin(); auto new_op_iter = new_iter->begin(); for (; op_iter != iter->end(); ++op_iter, ++new_op_iter) { - Operation &op = *op_iter; + const Operation &op = *op_iter; Operation &new_op = *new_op_iter; // operands of new_op are same as op, now map them. for (uint32_t i = 0; i < op.num_operands(); ++i) diff --git a/paddle/pir/core/region.h b/paddle/pir/core/region.h index 090473b7fb986..54129bb02433c 100644 --- a/paddle/pir/core/region.h +++ b/paddle/pir/core/region.h @@ -73,7 +73,7 @@ class IR_API Region { void Walk(FuncT &&callback); // clone this region into another region, target region will be overwritten. - void CloneInto(Region &other, IrMapping &ir_mapping); // NOLINT + void CloneInto(Region &other, IrMapping &ir_mapping) const; // NOLINT // take the last block of region. // if region is empty, return nullptr; diff --git a/test/cpp/pir/core/ir_region_test.cc b/test/cpp/pir/core/ir_region_test.cc index 055db2f841d8d..10ea2e599d1e8 100644 --- a/test/cpp/pir/core/ir_region_test.cc +++ b/test/cpp/pir/core/ir_region_test.cc @@ -80,7 +80,7 @@ TEST(region, clone_op_test) { pir::Operation &op = *program.module_op(); pir::Block &block = op.region(0).front(); pir::IrMapping mapper; - pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions(true, true)); + pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions::All()); // (8) Check the cloned op recursively EXPECT_EQ(mapper.Lookup(&op), &new_op); @@ -103,8 +103,7 @@ TEST(region, clone_op_test) { } EXPECT_EQ(op.num_results(), new_op.num_results()); for (uint32_t i = 0; i < op.num_results(); ++i) { - EXPECT_EQ(mapper.Lookup(static_cast(op.result(i))), - static_cast(new_op.result(i))); + EXPECT_EQ(mapper.Lookup(op.result(i)), new_op.result(i)); } EXPECT_TRUE(std::equal(op.attributes().begin(), op.attributes().end(),