diff --git a/src/vt/vrt/collection/balance/baselb/baselb.h b/src/vt/vrt/collection/balance/baselb/baselb.h index 9c6cd27b63..ce0213043a 100644 --- a/src/vt/vrt/collection/balance/baselb/baselb.h +++ b/src/vt/vrt/collection/balance/baselb/baselb.h @@ -164,6 +164,16 @@ struct BaseLB { bool isCommAware() const { return comm_aware_; } void recvSharedEdges(CommMsg* msg); + /** + * \brief Get the estimated time needed for load balancing + * + * \return the estimated time + */ + double getCollectiveEpochCost() const { + // 100 ns + return 0.0000001; + } + protected: void getArgs(PhaseType phase); diff --git a/src/vt/vrt/collection/balance/greedylb/greedylb.cc b/src/vt/vrt/collection/balance/greedylb/greedylb.cc index c82b6f818b..e9ec43b4af 100644 --- a/src/vt/vrt/collection/balance/greedylb/greedylb.cc +++ b/src/vt/vrt/collection/balance/greedylb/greedylb.cc @@ -160,7 +160,8 @@ void GreedyLB::loadStats() { bool should_lb = false; this_load_begin = this_load; - if (avg_load > 0.0000000001) { + // Use an estimated load-balancing cost on average rank load to load-balance + if (avg_load > getCollectiveEpochCost()) { should_lb = I > greedy_tolerance; } diff --git a/src/vt/vrt/collection/balance/hierarchicallb/hierlb.cc b/src/vt/vrt/collection/balance/hierarchicallb/hierlb.cc index 9eab2a25fe..215694fa5b 100644 --- a/src/vt/vrt/collection/balance/hierarchicallb/hierlb.cc +++ b/src/vt/vrt/collection/balance/hierarchicallb/hierlb.cc @@ -263,7 +263,8 @@ void HierarchicalLB::loadStats() { bool should_lb = false; this_load_begin = this_load; - if (avg_load > 0.0000000001) { + // Use an estimated load-balancing cost on average rank load to load-balance + if (avg_load > getCollectiveEpochCost()) { should_lb = I > hierlb_tolerance; } diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc index 6dba6e0ecf..f802187391 100644 --- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc +++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc @@ -58,7 +58,6 @@ #include #include #include -#include namespace vt { namespace vrt { namespace collection { namespace lb { @@ -439,9 +438,6 @@ void TemperedLB::runLB(LoadType total_load) { auto const imb = stats.at(lb::Statistic::Rank_load_modeled).at( lb::StatisticQuantity::imb ); - auto const min = stats.at(lb::Statistic::Object_load_modeled).at( - lb::StatisticQuantity::min - ); auto const load = this_load; if (target_pole_) { @@ -453,8 +449,8 @@ void TemperedLB::runLB(LoadType total_load) { target_max_load_ = avg; } - // Use an minimal object load on average rank load to load-balance - if (avg > min / theContext()->getNumNodes()) { + // Use an estimated load-balancing cost on average rank load to load-balance + if (avg > getCollectiveEpochCost()) { should_lb = max > (run_temperedlb_tolerance + 1.0) * target_max_load_; }