diff --git a/src/vt/collective/reduce/allreduce/allreduce_holder.cc b/src/vt/collective/reduce/allreduce/allreduce_holder.cc index 38cf2a71b4..5395f38c67 100644 --- a/src/vt/collective/reduce/allreduce/allreduce_holder.cc +++ b/src/vt/collective/reduce/allreduce/allreduce_holder.cc @@ -55,8 +55,10 @@ objgroup::proxy::Proxy AllreduceHolder::addRabensifnerAllreducer( col_reducers_[coll_proxy].first = obj_proxy.getProxy(); - fmt::print( - "Adding new Rabenseifner reducer for collection={:x}\n", coll_proxy); + vt_debug_print( + verbose, allreduce, "Adding new Rabenseifner reducer for collection={:x}", + coll_proxy + ); return obj_proxy; } @@ -70,8 +72,11 @@ AllreduceHolder::addRecursiveDoublingAllreducer( "recursive_doubling_allreducer", strong_proxy, strong_group, num_elems); col_reducers_[coll_proxy].second = obj_proxy.getProxy(); - fmt::print( - "Adding new RecursiveDoubling reducer for collection={:x}\n", coll_proxy); + + vt_debug_print( + verbose, allreduce, + "Adding new RecursiveDoubling reducer for collection={:x}", coll_proxy + ); return obj_proxy; } @@ -85,9 +90,10 @@ AllreduceHolder::addRabensifnerAllreducer(detail::StrongGroup strong_group) { group_reducers_[group].first = obj_proxy.getProxy(); - fmt::print( - "Adding new Rabenseifner reducer for group={:x} Size={}\n", group, - group_reducers_.size()); + vt_debug_print( + verbose, allreduce, + "Adding new Rabenseifner reducer for group={:x}", group + ); return obj_proxy; } @@ -100,13 +106,51 @@ AllreduceHolder::addRecursiveDoublingAllreducer( auto obj_proxy = theObjGroup()->makeCollective( "recursive_doubling_allreducer", strong_group); - fmt::print("Adding new RecursiveDoubling reducer for group={:x}\n", group); + vt_debug_print( + verbose, allreduce, + "Adding new Rabenseifner reducer for group={:x}", group + ); group_reducers_[group].second = obj_proxy.getProxy(); return obj_proxy; } +objgroup::proxy::Proxy +AllreduceHolder::addRabensifnerAllreducer(detail::StrongObjGroup strong_objgroup) { + auto const objgroup = strong_objgroup.get(); + + auto obj_proxy = theObjGroup()->makeCollective( + "rabenseifer_allreducer", strong_objgroup); + + objgroup_reducers_[objgroup].first = obj_proxy.getProxy(); + + vt_debug_print( + verbose, allreduce, + "Adding new Rabenseifner reducer for objgroup={:x} Size={}\n", objgroup + ); + + return obj_proxy; +} + +objgroup::proxy::Proxy +AllreduceHolder::addRecursiveDoublingAllreducer( + detail::StrongObjGroup strong_objgroup) { + auto const objgroup = strong_objgroup.get(); + + auto obj_proxy = theObjGroup()->makeCollective( + "recursive_doubling_allreducer", strong_objgroup); + + vt_debug_print( + verbose, allreduce, + "Adding new RecursiveDoubling reducer for objgroup={:x}\n", objgroup + ); + + objgroup_reducers_[objgroup].second = obj_proxy.getProxy(); + + return obj_proxy; +} + void AllreduceHolder::remove(detail::StrongVrtProxy strong_proxy) { auto const key = strong_proxy.get(); @@ -127,4 +171,14 @@ void AllreduceHolder::remove(detail::StrongGroup strong_group) { } } +void AllreduceHolder::remove(detail::StrongObjGroup strong_objgroup) { + auto const key = strong_objgroup.get(); + + auto it = objgroup_reducers_.find(key); + + if (it != objgroup_reducers_.end()) { + objgroup_reducers_.erase(key); + } +} + } // namespace vt::collective::reduce::allreduce diff --git a/src/vt/collective/reduce/allreduce/allreduce_holder.h b/src/vt/collective/reduce/allreduce/allreduce_holder.h index b60a232a0c..3bcb7f4eeb 100644 --- a/src/vt/collective/reduce/allreduce/allreduce_holder.h +++ b/src/vt/collective/reduce/allreduce/allreduce_holder.h @@ -118,6 +118,33 @@ struct AllreduceHolder { } } + template + static auto getAllreducer(detail::StrongObjGroup strong_objgroup) { + auto const objgroup = strong_objgroup.get(); + + if (auto it = objgroup_reducers_.find(objgroup); it == objgroup_reducers_.end()) { + objgroup_reducers_[objgroup] = {u64empty, u64empty}; + } + + if constexpr (std::is_same_v) { + auto untyped_proxy = objgroup_reducers_.at(objgroup).first; + if (untyped_proxy == u64empty) { + return addRabensifnerAllreducer(strong_objgroup); + } else { + return static_cast>( + untyped_proxy); + } + } else { + auto untyped_proxy = objgroup_reducers_.at(objgroup).second; + if (untyped_proxy == u64empty) { + return addRecursiveDoublingAllreducer(strong_objgroup); + } else { + return static_cast>( + untyped_proxy); + } + } + } + static objgroup::proxy::Proxy addRabensifnerAllreducer( detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group, size_t num_elems); @@ -132,8 +159,14 @@ struct AllreduceHolder { static objgroup::proxy::Proxy addRecursiveDoublingAllreducer(detail::StrongGroup strong_group); + static objgroup::proxy::Proxy + addRabensifnerAllreducer(detail::StrongObjGroup strong_group); + static objgroup::proxy::Proxy + addRecursiveDoublingAllreducer(detail::StrongObjGroup strong_group); + static void remove(detail::StrongVrtProxy strong_proxy); static void remove(detail::StrongGroup strong_group); + static void remove(detail::StrongObjGroup strong_group); static inline std::unordered_map< VirtualProxyType, std::pair> diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.cc b/src/vt/collective/reduce/allreduce/rabenseifner.cc index 9ab4443e31..9edd45a7a8 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.cc +++ b/src/vt/collective/reduce/allreduce/rabenseifner.cc @@ -42,6 +42,7 @@ */ #include "vt/collective/reduce/allreduce/rabenseifner.h" +#include "vt/collective/reduce/allreduce/allreduce_holder.h" #include "vt/configs/error/config_assert.h" #include "vt/group/group_manager.h" @@ -67,16 +68,7 @@ Rabenseifner::Rabenseifner( } is_even_ = this_node_ % 2 == 0; - is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_); - if (is_part_of_adjustment_group_) { - if (is_even_) { - vrt_node_ = this_node_ / 2; - } else { - vrt_node_ = -1; - } - } else { - vrt_node_ = this_node_ - nprocs_rem_; - } + initializeVrtNode(); vt_debug_print( terse, allreduce, @@ -93,21 +85,15 @@ Rabenseifner::Rabenseifner(detail::StrongGroup group) num_steps_(static_cast(log2(num_nodes_))), nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_) { - std::string nodes_info; - for (auto& node : nodes_) { - nodes_info += fmt::format("{} ", node); - } auto const is_default_group = theGroup()->isGroupDefault(group_); auto const in_group = theGroup()->inGroup(group_); auto const is_part_of_allreduce = - (not is_default_group and in_group) or - is_default_group; + (not is_default_group and in_group) or is_default_group; vt_debug_print( terse, allreduce, - "Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} " - "Nodes:[{}]\n", - is_default_group, is_part_of_allreduce, num_nodes_, nodes_info); + "Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} \n", + is_default_group, is_part_of_allreduce, num_nodes_); if (not is_default_group and in_group) { auto it = std::find(nodes_.begin(), nodes_.end(), theContext()->getNode()); @@ -120,16 +106,8 @@ Rabenseifner::Rabenseifner(detail::StrongGroup group) // We collectively create this Reducer, so it's possible that not all Nodes are part of it if (is_part_of_allreduce) { is_even_ = this_node_ % 2 == 0; - is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_); - if (is_part_of_adjustment_group_) { - if (is_even_) { - vrt_node_ = this_node_ / 2; - } else { - vrt_node_ = -1; - } - } else { - vrt_node_ = this_node_ - nprocs_rem_; - } + + initializeVrtNode(); } } @@ -141,17 +119,21 @@ Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup) num_steps_(static_cast(log2(num_nodes_))), nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_) { + nodes_.resize(num_nodes_); for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) { nodes_[i] = i; } + initializeVrtNode(); + vt_debug_print( terse, allreduce, "Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} \n", true, true, num_nodes_); +} - // We collectively create this Reducer, so it's possible that not all Nodes are part of it +void Rabenseifner::initializeVrtNode() { is_even_ = this_node_ % 2 == 0; is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_); if (is_part_of_adjustment_group_) { @@ -166,12 +148,10 @@ Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup) } Rabenseifner::~Rabenseifner() { - if (collection_proxy_ != u64empty) { - // StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_}); - - } else if (objgroup_proxy_ != u64empty) { + if (objgroup_proxy_ != u64empty) { StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_}); - } else { + AllreduceHolder::remove(detail::StrongObjGroup{objgroup_proxy_}); + } else if(group_ != u64empty){ StateHolder::clearAll(detail::StrongGroup{group_}); } } diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.h b/src/vt/collective/reduce/allreduce/rabenseifner.h index 98cde0b1bf..e49b09725e 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.h @@ -84,6 +84,9 @@ struct Rabenseifner { * \param objgroup ObjGroupProxy */ Rabenseifner(detail::StrongObjGroup objgroup); + + void initializeVrtNode(); + ~Rabenseifner(); /** diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.cc b/src/vt/collective/reduce/allreduce/recursive_doubling.cc index ae6fcd5029..ef9e787b31 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.cc +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.cc @@ -42,7 +42,6 @@ */ #include "vt/collective/reduce/allreduce/recursive_doubling.h" -#include "vt/collective/reduce/scoping/strong_types.h" #include "vt/group/group_manager.h" namespace vt::collective::reduce::allreduce { @@ -67,15 +66,7 @@ RecursiveDoubling::RecursiveDoubling( is_even_ = this_node_ % 2 == 0; is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_); - if (is_part_of_adjustment_group_) { - if (is_even_) { - vrt_node_ = this_node_ / 2; - } else { - vrt_node_ = -1; - } - } else { - vrt_node_ = this_node_ - nprocs_rem_; - } + initializeVrtNode(); vt_debug_print( terse, allreduce, @@ -83,8 +74,7 @@ RecursiveDoubling::RecursiveDoubling( print_ptr(this), proxy.get(), proxy_.getProxy(), local_num_elems_); } -RecursiveDoubling::RecursiveDoubling( - detail::StrongObjGroup objgroup) +RecursiveDoubling::RecursiveDoubling(detail::StrongObjGroup objgroup) : objgroup_proxy_(objgroup.get()), local_num_elems_(1), num_nodes_(theContext()->getNumNodes()), @@ -94,15 +84,12 @@ RecursiveDoubling::RecursiveDoubling( nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_), is_part_of_adjustment_group_(this_node_ < (2 * nprocs_rem_)) { - if (is_part_of_adjustment_group_) { - if (is_even_) { - vrt_node_ = this_node_ / 2; - } else { - vrt_node_ = -1; - } - } else { - vrt_node_ = this_node_ - nprocs_rem_; + nodes_.resize(num_nodes_); + for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) { + nodes_[i] = i; } + + initializeVrtNode(); } RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group) @@ -116,6 +103,10 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group) nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_), is_part_of_adjustment_group_(this_node_ < (2 * nprocs_rem_)) { + initializeVrtNode(); +} + +void RecursiveDoubling::initializeVrtNode() { if (is_part_of_adjustment_group_) { if (is_even_) { vrt_node_ = this_node_ / 2; @@ -128,11 +119,10 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group) } RecursiveDoubling::~RecursiveDoubling() { - if (collection_proxy_ != u64empty) { - // StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_}); - } else if (objgroup_proxy_ != u64empty) { + if (objgroup_proxy_ != u64empty) { StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_}); - } else { + AllreduceHolder::remove(detail::StrongObjGroup{objgroup_proxy_}); + } else if (group_ != u64empty) { StateHolder::clearAll(detail::StrongGroup{group_}); } } diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.h b/src/vt/collective/reduce/allreduce/recursive_doubling.h index 3422a3c041..f5fb8591ad 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.h @@ -84,6 +84,11 @@ struct RecursiveDoubling { */ RecursiveDoubling(detail::StrongGroup group); + /** + * \brief Initialize vrt_node_ variable + */ + void initializeVrtNode(); + ~RecursiveDoubling(); /** diff --git a/src/vt/objgroup/manager.h b/src/vt/objgroup/manager.h index 6fc3c82614..00f88f5a3b 100644 --- a/src/vt/objgroup/manager.h +++ b/src/vt/objgroup/manager.h @@ -519,10 +519,6 @@ struct ObjGroupManager : runtime::component::Component { std::unordered_map> pending_; /// Map of object groups' labels std::unordered_map labels_; - /// Recursive Doubling reducers - ReducerMapType reducers_recursive_doubling_; - /// Rabenseifner reducers - ReducerMapType reducers_rabenseifner_; }; }} /* end namespace vt::objgroup */ diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h index 61fea8801e..241fedd198 100644 --- a/src/vt/objgroup/manager.impl.h +++ b/src/vt/objgroup/manager.impl.h @@ -41,6 +41,7 @@ //@HEADER */ +#include "vt/collective/reduce/allreduce/allreduce_holder.h" #if !defined INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H #define INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H @@ -357,44 +358,13 @@ ObjGroupManager::allreduce(ProxyType proxy, Args&&... data) { auto const this_node = vt::theContext()->getNode(); auto const strong_proxy = vt::collective::reduce::detail::StrongObjGroup{proxy.getProxy()}; - if constexpr (std::is_same_v) { - using Reducer = Rabenseifner; - - auto const id = StateHolder::getNextID(strong_proxy); - - auto grp_proxy = vt::theObjGroup()->makeCollective( - TypeToString(Reducer::type_), strong_proxy); - grp_proxy[this_node].get()->proxy_ = grp_proxy; - grp_proxy[this_node].get()->template setFinalHandler(cb, id); - grp_proxy[this_node].get()->template localReduce(id, std::forward(data)...); - // return PendingSendType{ - // theTerm()->getEpoch(), - // [&, this, args = std::make_tuple(std::forward(data)...)] { - // std::apply( - // [&, this](auto&&... unpackedArgs) { - // grp_proxy[this_node].template invoke<&Reducer::template localReduce>( - // id, std::forward(unpackedArgs)... - // ); - // }, - // std::move(args)); - // }}; - return PendingSendType{nullptr}; - } else if (std::is_same_v) { - using Reducer = RecursiveDoubling; - auto const id = StateHolder::getNextID(strong_proxy); - - auto grp_proxy = vt::theObjGroup()->makeCollective( - TypeToString(Reducer::type_), strong_proxy - ); - grp_proxy[this_node].get()->proxy_ = grp_proxy; - grp_proxy[this_node].get()->template setFinalHandler(cb, id); - grp_proxy[this_node].get()->template localReduce(id, std::forward(data)...); - // return PendingSendType{ - // theTerm()->getEpoch(), - // [=] { grp_proxy[this_node].template invoke<&Reducer::localReduce>(id); }}; - } else { - vtAssert(true, "Unknown allreduce algorithm type!"); - } + auto const id = StateHolder::getNextID(strong_proxy); + + auto grp_proxy = AllreduceHolder::getAllreducer(strong_proxy); + grp_proxy[this_node].get()->proxy_ = grp_proxy; + grp_proxy[this_node].get()->template setFinalHandler(cb, id); + grp_proxy[this_node].get()->template localReduce( + id, std::forward(data)...); // Silence nvcc warning return PendingSendType{nullptr};