From bd9fb6145a1988c47ef24b23b2518e01f2253b1d Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Tue, 9 Jan 2024 16:27:51 +0000 Subject: [PATCH] [PIR] polish the ir_mapping implimentation. --- paddle/fluid/pybind/pir.cc | 2 +- paddle/pir/core/ir_mapping.h | 84 +++++++++++++++++++++--------------- paddle/pir/core/operation.cc | 3 +- 3 files changed, 51 insertions(+), 38 deletions(-) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index c4c9eb0145aa59..bbeddda969c154 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1030,7 +1030,7 @@ std::pair, OpResultMap> CloneProgram( pir::IrMapping mapper; auto cloned_program = program.Clone(mapper); std::vector associated_array_key, associated_array_value; - for (auto &pair : mapper.Map()) { + for (auto &pair : mapper.GetMap()) { associated_array_key.push_back(pair.first.dyn_cast()); associated_array_value.push_back(pair.second.dyn_cast()); } diff --git a/paddle/pir/core/ir_mapping.h b/paddle/pir/core/ir_mapping.h index d33ad49962c2ca..9eeed8988bb9cc 100644 --- a/paddle/pir/core/ir_mapping.h +++ b/paddle/pir/core/ir_mapping.h @@ -20,24 +20,65 @@ namespace pir { class Block; class Operation; +namespace detail { +template +struct ExactlyOneIrType { + using type = void; +}; +template +struct ExactlyOneIrType { + using type = + std::conditional_t::value, + FirstT, + typename ExactlyOneIrType::type>; +}; +} // namespace detail class IrMapping { public: template - void Add(T from, T to) { + using IrType = + typename detail::ExactlyOneIrType::type; + template + std::unordered_map& GetMap() { + if constexpr (std::is_same::value) { + return value_map_; + } else if constexpr (std::is_same::value) { + return block_map_; + } else if constexpr (std::is_same::value) { + return operation_map_; + } else { + IR_THROW("Not support type in IRMapping."); + } + } + template + const std::unordered_map& GetMap() const { + if constexpr (std::is_same::value) { + return value_map_; + } else if constexpr (std::is_same::value) { + return block_map_; + } else if constexpr (std::is_same::value) { + return operation_map_; + } else { + IR_THROW("Not support type in IRMapping."); + } + } + template + void Add(T from, S to) { if (!from) return; - MutableMap()[from] = to; + GetMap>()[from] = to; } template T Lookup(T from) const { if (!from) return static_cast(nullptr); - IR_ENFORCE(Map().count(from) > 0, "Not found key in IRMapping."); - return Map().at(from); + IR_ENFORCE(GetMap>().count(from) > 0, + "Not found key in IRMapping."); + return GetMap>().at(from); } template void Earse(T from) { - MutableMap().erase(from); + GetMap>().erase(from); } void Clear() { @@ -46,37 +87,10 @@ class IrMapping { operation_map_.clear(); } - template - using MapType = std::unordered_map; - - template - const MapType &Map() const { - if constexpr (std::is_convertible::value) - return value_map_; - else if constexpr (std::is_convertible::value) - return block_map_; - else if constexpr (std::is_convertible::value) - return operation_map_; - else - IR_THROW("Not support type in IRMapping."); - } - - template - MapType &MutableMap() { - if constexpr (std::is_convertible::value) - return value_map_; - else if constexpr (std::is_convertible::value) - return block_map_; - else if constexpr (std::is_convertible::value) - return operation_map_; - else - IR_THROW("Not support type in IRMapping."); - } - private: - MapType value_map_; - MapType block_map_; - MapType operation_map_; + std::unordered_map value_map_; + std::unordered_map block_map_; + std::unordered_map operation_map_; }; } // namespace pir diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 00db9ce80fc148..c18bcfc09a9f36 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -158,8 +158,7 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) { // record outputs mapping info for (uint32_t i = 0; i < num_results_; ++i) { - ir_mapping.Add(static_cast(result(i)), - static_cast(new_op->result(i))); + ir_mapping.Add(result(i), new_op->result(i)); } if (options.IsCloneRegions()) {