From 5c33e20d507eeb200083fb6d9c9f63f6fe4c74d8 Mon Sep 17 00:00:00 2001 From: Jacob Domagala Date: Sat, 12 Oct 2024 01:31:52 +0200 Subject: [PATCH] #2281: Resolve issue with incorrect index generated by StateHolder::getNextID --- .../reduce/allreduce/allreduce_holder.h | 2 +- .../reduce/allreduce/rabenseifner.cc | 9 ----- .../reduce/allreduce/rabenseifner.h | 2 -- .../reduce/allreduce/recursive_doubling.cc | 9 ----- .../reduce/allreduce/recursive_doubling.h | 2 -- .../allreduce/recursive_doubling.impl.h | 2 -- .../reduce/allreduce/state_holder.cc | 34 ++++++++++++------- .../reduce/allreduce/state_holder.h | 11 +++--- .../reduce/allreduce/state_holder.impl.h | 3 +- src/vt/objgroup/manager.impl.h | 6 +++- .../vrt/collection/holders/typeless_holder.cc | 3 -- 11 files changed, 33 insertions(+), 50 deletions(-) diff --git a/src/vt/collective/reduce/allreduce/allreduce_holder.h b/src/vt/collective/reduce/allreduce/allreduce_holder.h index 9f3d0b2e4b..8b9a083227 100644 --- a/src/vt/collective/reduce/allreduce/allreduce_holder.h +++ b/src/vt/collective/reduce/allreduce/allreduce_holder.h @@ -47,7 +47,6 @@ #include "vt/configs/types/types_type.h" #include "vt/collective/reduce/allreduce/type.h" #include "vt/collective/reduce/scoping/strong_types.h" -#include "vt/configs/types/types_sentinels.h" #include "vt/objgroup/proxy/proxy_objgroup.h" #include @@ -56,6 +55,7 @@ namespace vt::collective::reduce::allreduce { struct Rabenseifner; struct RecursiveDoubling; + struct AllreduceHolder { using RabenseifnerProxy = ObjGroupProxyType; using RecursiveDoublingProxy = ObjGroupProxyType; diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.cc b/src/vt/collective/reduce/allreduce/rabenseifner.cc index 6ce8e409f6..1ec410c4e7 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.cc +++ b/src/vt/collective/reduce/allreduce/rabenseifner.cc @@ -150,13 +150,4 @@ void Rabenseifner::initializeVrtNode() { } } -Rabenseifner::~Rabenseifner() { - if (info_.first == ComponentT::ObjGroup) { - StateHolder::clearAll(detail::StrongObjGroup{info_.second}); - AllreduceHolder::remove(detail::StrongObjGroup{info_.second}); - } else if(info_.first == ComponentT::Group){ - StateHolder::clearAll(detail::StrongGroup{info_.second}); - } -} - } // namespace vt::collective::reduce::allreduce diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.h b/src/vt/collective/reduce/allreduce/rabenseifner.h index 3837cc0262..9940f7fa4f 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.h @@ -87,8 +87,6 @@ struct Rabenseifner { void initializeVrtNode(); - ~Rabenseifner(); - /** * \brief Set final handler that will be executed with allreduce result * diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.cc b/src/vt/collective/reduce/allreduce/recursive_doubling.cc index 845f048556..22099f86e7 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.cc +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.cc @@ -117,13 +117,4 @@ void RecursiveDoubling::initializeVrtNode() { } } -RecursiveDoubling::~RecursiveDoubling() { -if (info_.first == ComponentT::ObjGroup) { - StateHolder::clearAll(detail::StrongObjGroup{info_.second}); - AllreduceHolder::remove(detail::StrongObjGroup{info_.second}); - } else if(info_.first == ComponentT::Group){ - StateHolder::clearAll(detail::StrongGroup{info_.second}); - } -} - } // namespace vt::collective::reduce::allreduce diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.h b/src/vt/collective/reduce/allreduce/recursive_doubling.h index d5d6187311..49f348c7be 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.h @@ -89,8 +89,6 @@ struct RecursiveDoubling { */ void initializeVrtNode(); - ~RecursiveDoubling(); - /** * \brief Execute the final handler callback with the reduced result. * diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h index 2bf4c3e7b5..307ef92b46 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h @@ -177,7 +177,6 @@ template class Op> template class Op> void RecursiveDoubling::adjustForPowerOfTwoHan( RecursiveDoublingMsg* msg) { - using DataType = DataHandler; auto& state = getState(info_, msg->id_); if (not state.value_assigned_) { if (not state.initialized_) { @@ -311,7 +310,6 @@ RecursiveDoubling::reduceIterHandler(RecursiveDoublingMsg* msg) { template class Op> void RecursiveDoubling::reduceIterHan(RecursiveDoublingMsg* msg) { - using DataType = DataHandler; auto& state = getState(info_, msg->id_); if (not state.value_assigned_) { diff --git a/src/vt/collective/reduce/allreduce/state_holder.cc b/src/vt/collective/reduce/allreduce/state_holder.cc index 85edbfa7be..0c3c81cb2e 100644 --- a/src/vt/collective/reduce/allreduce/state_holder.cc +++ b/src/vt/collective/reduce/allreduce/state_holder.cc @@ -52,6 +52,9 @@ size_t getNextIdImpl(StateHolder::StatesVec& states, size_t idx) { size_t id = u64empty; + vt_debug_print( + terse, allreduce, "getNextIdImpl idx={} size={} \n", idx, states.size()); + for (; idx < states.size(); ++idx) { auto& state = states.at(idx); if (not state or not state->active_) { @@ -64,28 +67,35 @@ getNextIdImpl(StateHolder::StatesVec& states, size_t idx) { id = states.size(); } + return id; } size_t StateHolder::getNextID(detail::StrongVrtProxy proxy) { - auto& states = active_coll_states_[proxy.get()]; + auto& [idx, states] = active_coll_states_[proxy.get()]; + + auto current_idx = getNextIdImpl(states, idx); + idx = current_idx + 1; - collection_idx_ = getNextIdImpl(states, collection_idx_); - return collection_idx_; + return current_idx; } size_t StateHolder::getNextID(detail::StrongObjGroup proxy) { - auto& states = active_obj_states_[proxy.get()]; + auto& [idx, states] = active_obj_states_[proxy.get()]; - objgroup_idx_ = getNextIdImpl(states, objgroup_idx_); - return objgroup_idx_; + auto current_idx = getNextIdImpl(states, idx); + idx = current_idx + 1; + + return current_idx; } size_t StateHolder::getNextID(detail::StrongGroup group) { - auto& states = active_grp_states_[group.get()]; + auto& [idx, states] = active_grp_states_[group.get()]; + + auto current_idx = getNextIdImpl(states, idx); - group_idx_ = getNextIdImpl(states, group_idx_); - return group_idx_; + idx = current_idx + 1; + return current_idx; } static inline void @@ -101,19 +111,19 @@ clearSingleImpl(StateHolder::StatesVec& states, size_t idx) { } void StateHolder::clearSingle(detail::StrongVrtProxy proxy, size_t idx) { - auto& states = active_coll_states_[proxy.get()]; + auto& [_, states] = active_coll_states_[proxy.get()]; clearSingleImpl(states, idx); } void StateHolder::clearSingle(detail::StrongObjGroup proxy, size_t idx) { - auto& states = active_obj_states_[proxy.get()]; + auto& [_, states] = active_obj_states_[proxy.get()]; clearSingleImpl(states, idx); } void StateHolder::clearSingle(detail::StrongGroup group, size_t idx) { - auto& states = active_grp_states_[group.get()]; + auto& [_, states] = active_grp_states_[group.get()]; clearSingleImpl(states, idx); } diff --git a/src/vt/collective/reduce/allreduce/state_holder.h b/src/vt/collective/reduce/allreduce/state_holder.h index 6983f8be30..93055403f3 100644 --- a/src/vt/collective/reduce/allreduce/state_holder.h +++ b/src/vt/collective/reduce/allreduce/state_holder.h @@ -57,6 +57,7 @@ namespace vt::collective::reduce::allreduce { struct StateHolder { using StatesVec = std::vector>; + using StatesInfo = std::pair; template < typename ReducerT, typename DataT, @@ -86,17 +87,13 @@ struct StateHolder { static void clearAll(detail::StrongGroup group); private: - static inline size_t collection_idx_ = 0; - static inline size_t objgroup_idx_ = 0; - static inline size_t group_idx_ = 0; - - static inline std::unordered_map + static inline std::unordered_map active_coll_states_ = {}; - static inline std::unordered_map + static inline std::unordered_map active_obj_states_ = {}; - static inline std::unordered_map active_grp_states_ = + static inline std::unordered_map active_grp_states_ = {}; }; diff --git a/src/vt/collective/reduce/allreduce/state_holder.impl.h b/src/vt/collective/reduce/allreduce/state_holder.impl.h index 16a51af8cf..f4bf85d0d4 100644 --- a/src/vt/collective/reduce/allreduce/state_holder.impl.h +++ b/src/vt/collective/reduce/allreduce/state_holder.impl.h @@ -41,7 +41,6 @@ //@HEADER */ -#include "vt/collective/reduce/allreduce/state.h" #if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H @@ -80,7 +79,7 @@ template < typename Scalar = typename DataHandler::Scalar, typename ProxyT, typename MapT> static auto& getStateImpl(ProxyT proxy, MapT& states_map, size_t idx) { - auto& states = states_map[proxy.get()]; + auto& [_, states] = states_map[proxy.get()]; auto const num_states = states.size(); if (idx >= num_states || num_states == 0) { diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h index 2edfdf23d9..ea381dbb66 100644 --- a/src/vt/objgroup/manager.impl.h +++ b/src/vt/objgroup/manager.impl.h @@ -63,7 +63,7 @@ #include "vt/collective/reduce/allreduce/type.h" #include "vt/collective/reduce/allreduce/helpers.h" #include "vt/collective/reduce/scoping/strong_types.h" -#include "vt/collective/reduce/allreduce/state_holder.h" +#include "vt/collective/reduce/allreduce/allreduce_holder.h" #include "vt/pipe/pipe_manager.h" #include @@ -147,6 +147,10 @@ void ObjGroupManager::destroyCollective(ProxyType proxy) { if (label_iter != labels_.end()) { labels_.erase(label_iter); } + + vt::collective::reduce::allreduce::AllreduceHolder::remove( + vt::collective::reduce::detail::StrongObjGroup{proxy.getProxy()} + ); } template diff --git a/src/vt/vrt/collection/holders/typeless_holder.cc b/src/vt/vrt/collection/holders/typeless_holder.cc index 2e769e1652..82569ae122 100644 --- a/src/vt/vrt/collection/holders/typeless_holder.cc +++ b/src/vt/vrt/collection/holders/typeless_holder.cc @@ -76,9 +76,6 @@ void TypelessHolder::destroyCollection(VirtualProxyType const proxy) { } } - vt::collective::reduce::allreduce::StateHolder::clearAll( - vt::collective::reduce::detail::StrongVrtProxy{proxy}); - vt::collective::reduce::allreduce::AllreduceHolder::remove( vt::collective::reduce::detail::StrongVrtProxy{proxy} );