diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc index e9926fe121..158dc6210d 100644 --- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc +++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc @@ -475,7 +475,7 @@ void TemperedLB::runLB(TimeType total_load) { } void TemperedLB::clearDataStructures() { - underloaded_.clear(); + potential_recipients_.clear(); load_info_.clear(); is_overloaded_ = is_underloaded_ = false; } @@ -666,7 +666,7 @@ void TemperedLB::informAsync() { auto const this_node = theContext()->getNode(); if (canPropagate()) { - underloaded_.insert(this_node); + potential_recipients_.insert(this_node); } setup_done_ = false; @@ -693,7 +693,7 @@ void TemperedLB::informAsync() { vt_debug_print( terse, temperedlb, "TemperedLB::informAsync: trial={}, iter={}, known underloaded={}\n", - trial_, iter_, underloaded_.size() + trial_, iter_, potential_recipients_.size() ); } @@ -717,12 +717,12 @@ void TemperedLB::informSync() { auto const this_node = theContext()->getNode(); if (canPropagate()) { - underloaded_.insert(this_node); + potential_recipients_.insert(this_node); } auto propagate_this_round = canPropagate(); propagate_next_round_ = false; - new_underloaded_ = underloaded_; + new_potential_recipients_ = potential_recipients_; new_load_info_ = load_info_; setup_done_ = false; @@ -752,7 +752,7 @@ void TemperedLB::informSync() { propagate_this_round = propagate_next_round_; propagate_next_round_ = false; - underloaded_ = new_underloaded_; + potential_recipients_ = new_potential_recipients_; load_info_ = new_load_info_; } @@ -760,7 +760,7 @@ void TemperedLB::informSync() { vt_debug_print( terse, temperedlb, "TemperedLB::informSync: trial={}, iter={}, known underloaded={}\n", - trial_, iter_, underloaded_.size() + trial_, iter_, potential_recipients_.size() ); } @@ -791,7 +791,7 @@ void TemperedLB::propagateRound(uint8_t k_cur, bool sync, EpochType epoch) { gen_propagate_.seed(seed_()); } - auto& selected = underloaded_; + auto& selected = potential_recipients_; if (selected.find(this_node) == selected.end()) { selected.insert(this_node); } @@ -868,7 +868,7 @@ void TemperedLB::propagateIncomingAsync(LoadMsgAsync* msg) { load_info_[elm.first] = elm.second; if (isUnderloaded(elm.second)) { - underloaded_.insert(elm.first); + potential_recipients_.insert(elm.first); } } } @@ -902,7 +902,7 @@ void TemperedLB::propagateIncomingSync(LoadMsgSync* msg) { new_load_info_[elm.first] = elm.second; if (isUnderloaded(elm.second)) { - new_underloaded_.insert(elm.first); + new_potential_recipients_.insert(elm.first); } } } @@ -1204,7 +1204,7 @@ void TemperedLB::decide() { auto potential_recipients = getPotentialRecipients(); std::unordered_map migrate_objs; - if (potential_recipients.size() > 0) { + if (not potential_recipients.empty()) { std::vector ordered_obj_ids = orderObjects( obj_ordering_, cur_objs_, this_new_load_, target_max_load_ ); diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h index e5e7c0fb7e..5af5b501fb 100644 --- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h +++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h @@ -96,6 +96,9 @@ struct TemperedLB : BaseLB { void migrate(); void clearDataStructures(); + /** + * \brief Decides whether the rank can perform the migration + */ virtual bool canMigrate() const { return is_overloaded_; } /** * \brief Decides whether the rank can initiate information propagation stage @@ -132,6 +135,7 @@ struct TemperedLB : BaseLB { void setupDone(ReduceMsgType* msg); std::unordered_map load_info_ = {}; + std::unordered_map cur_objs_ = {}; private: uint16_t f_ = 0; @@ -175,9 +179,8 @@ struct TemperedLB : BaseLB { objgroup::proxy::Proxy proxy_ = {}; bool is_overloaded_ = false; bool is_underloaded_ = false; - std::unordered_set underloaded_ = {}; - std::unordered_set new_underloaded_ = {}; - std::unordered_map cur_objs_ = {}; + std::unordered_set potential_recipients_ = {}; + std::unordered_set new_potential_recipients_ = {}; LoadType this_new_load_ = 0.0; TimeType new_imbalance_ = 0.0; TimeType target_max_load_ = 0.0; diff --git a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc index 1788faa181..128d77a6f6 100644 --- a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc +++ b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc @@ -48,6 +48,8 @@ #include "vt/vrt/collection/balance/model/load_model.h" #include "vt/vrt/collection/balance/model/weighted_communication_volume.h" +#include + namespace vt { namespace vrt { namespace collection { namespace lb { void TemperedWMin::init(objgroup::proxy::Proxy in_proxy) { @@ -132,4 +134,13 @@ TimeType TemperedWMin::getModeledValue(const elm::ElementIDStruct& obj) { return total_work_model_->getModeledLoad(obj, when); } +bool TemperedWMin::canMigrate() const { + auto const this_node = theContext()->getNode(); + auto const another_rank = std::find_if( + load_info_.begin(), load_info_.end(), + [this_node](auto const& elm) { return elm.first != this_node; } + ); + return (not cur_objs_.empty()) and (another_rank != load_info_.end()); +} + }}}} // namespace vt::vrt::collection::lb diff --git a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h index 3da90fa4c3..107b6b4c77 100644 --- a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h +++ b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h @@ -66,6 +66,10 @@ struct TemperedWMin : TemperedLB { protected: TimeType getModeledValue(const elm::ElementIDStruct& obj) override; + /** + * Allow migration when there are objects to migrate and other ranks are known + */ + bool canMigrate() const override; /** * All ranks are allowed to initiate the information propagation stage */