Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1086 lb refactor #1087

Merged
merged 8 commits into from
Sep 25, 2020
1 change: 0 additions & 1 deletion src/vt/vrt/collection/balance/baselb/baselb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
#include "vt/config.h"
#include "vt/vrt/collection/balance/baselb/baselb.h"
#include "vt/vrt/collection/balance/lb_comm.h"
#include "vt/vrt/collection/balance/lb_invoke/start_lb_msg.h"
#include "vt/vrt/collection/balance/read_lb.h"
#include "vt/vrt/collection/balance/lb_invoke/lb_manager.h"
#include "vt/vrt/collection/balance/node_stats.h"
Expand Down
1 change: 0 additions & 1 deletion src/vt/vrt/collection/balance/baselb/baselb.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@

#include "vt/config.h"
#include "vt/vrt/collection/balance/lb_common.h"
#include "vt/vrt/collection/balance/lb_invoke/start_lb_msg.h"
#include "vt/vrt/collection/balance/baselb/baselb_msgs.h"
#include "vt/vrt/collection/balance/stats_msg.h"
#include "vt/vrt/collection/balance/lb_comm.h"
Expand Down
8 changes: 3 additions & 5 deletions src/vt/vrt/collection/balance/elm_stats.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,17 @@ template <typename ColT>
before_ready, after_ready, ready
);

using MsgType = InvokeReduceMsg;

auto lb_man = theLBManager()->getProxy();

auto const single_node = theContext()->getNumNodes() == 1;
auto const lb = lb_man.get()->decideLBToRun(cur_phase);
bool const must_run_lb = lb != LBType::NoLB and not single_node;
auto const num_collections = theCollection()->numCollections<>();
auto const do_sync = msg->doSync();
auto nmsg = makeMessage<MsgType>(cur_phase,lb,msg->manual(),num_collections);
auto nmsg = makeMessage<InvokeMsg>(cur_phase,lb,msg->manual(),num_collections);

if (must_run_lb) {
auto cb = theCB()->makeBcast<LBManager,MsgType,&LBManager::sysLB<MsgType>>(lb_man);
auto cb = theCB()->makeBcast<LBManager,InvokeMsg,&LBManager::sysLB>(lb_man);
proxy.reduce(nmsg.get(),cb);
} else {

Expand All @@ -129,7 +127,7 @@ template <typename ColT>
theCollection()->elmFinishedLB(elm_proxy,cur_phase);
}

auto cb = theCB()->makeBcast<LBManager,MsgType,&LBManager::sysReleaseLB<MsgType>>(lb_man);
auto cb = theCB()->makeBcast<LBManager,InvokeMsg,&LBManager::sysReleaseLB>(lb_man);
proxy.reduce(nmsg.get(),cb);
}
}
Expand Down
6 changes: 0 additions & 6 deletions src/vt/vrt/collection/balance/greedylb/greedylb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,6 @@ void GreedyLB::loadOverBin(ObjBinType bin, ObjBinListType& bin_list) {
auto const threshold = this_threshold * avg_load;
auto const obj_id = bin_list.back();

if (load_over.find(bin) == load_over.end()) {
load_over_size += sizeof(std::size_t) * 4;
load_over_size += sizeof(ObjBinType);
}
load_over_size += sizeof(ObjIDType);

load_over[bin].push_back(obj_id);
bin_list.pop_back();

Expand Down
7 changes: 2 additions & 5 deletions src/vt/vrt/collection/balance/greedylb/greedylb.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
#include "vt/vrt/collection/balance/greedylb/greedylb_types.h"
#include "vt/vrt/collection/balance/greedylb/greedylb_constants.h"
#include "vt/vrt/collection/balance/greedylb/greedylb_msgs.h"
#include "vt/vrt/collection/balance/lb_invoke/start_lb_msg.h"
#include "vt/vrt/collection/balance/baselb/baselb.h"
#include "vt/timing/timing.h"

Expand Down Expand Up @@ -98,14 +97,12 @@ struct GreedyLB : BaseLB {
static objgroup::proxy::Proxy<GreedyLB> scatter_proxy;

private:
double greedy_max_threshold = 0.0f;
double greedy_threshold = 0.0f;
bool greedy_auto_threshold = true;
double this_threshold = 0.0f;
LoadType this_load_begin = 0.0f;
ObjSampleType load_over;
std::size_t load_over_size = 0;
objgroup::proxy::Proxy<GreedyLB> proxy = {};

// Parameters read from LB spec file
double max_threshold = 0.0f;
double min_threshold = 0.0f;
bool auto_threshold = true;
Expand Down
26 changes: 6 additions & 20 deletions src/vt/vrt/collection/balance/hierarchicallb/hierlb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,13 @@ void HierarchicalLB::loadStats() {
calcLoadOver(extract_strategy);

lbTreeUpSend(
bottom_parent, this_load, this_node, load_over, 1, load_over_size
bottom_parent, this_load, this_node, load_over, 1
);

if (children.size() == 0) {
auto const& total_size = sizeof(std::size_t) * 4;
ObjSampleType empty_obj{};
lbTreeUpSend(
parent, hierlb_no_load_sentinel, this_node, empty_obj, agg_node_size,
total_size
parent, hierlb_no_load_sentinel, this_node, empty_obj, agg_node_size
);
}
}
Expand All @@ -237,12 +235,6 @@ void HierarchicalLB::loadOverBin(ObjBinType bin, ObjBinListType& bin_list) {
auto const threshold = this_threshold * getAvgLoad();
auto const obj_id = bin_list.back();

if (load_over.find(bin) == load_over.end()) {
load_over_size += sizeof(std::size_t) * 4;
load_over_size += sizeof(ObjBinType);
}
load_over_size += sizeof(ObjIDType);

load_over[bin].push_back(obj_id);
bin_list.pop_back();

Expand Down Expand Up @@ -391,8 +383,7 @@ std::size_t HierarchicalLB::getSize(ObjSampleType const& sample) {

void HierarchicalLB::lbTreeUpSend(
NodeType const node, LoadType const child_load, NodeType const child,
ObjSampleType const& load, NodeType const child_size,
std::size_t const& load_size_approx
ObjSampleType const& load, NodeType const child_size
) {
auto msg = makeMessage<LBTreeUpMsg>(child_load,child,load,child_size);
proxy[node].template send<LBTreeUpMsg,&HierarchicalLB::lbTreeUpHandler>(msg);
Expand Down Expand Up @@ -683,31 +674,26 @@ void HierarchicalLB::distributeAmoungChildren() {
}
}

auto const& data_size = clearObj(given_objs);
clearObj(given_objs);
lbTreeUpSend(
parent, total_child_load, this_node, given_objs, total_size, data_size
parent, total_child_load, this_node, given_objs, total_size
);

given_objs.clear();
}

std::size_t HierarchicalLB::clearObj(ObjSampleType& objs) {
std::size_t total_size = 0;
void HierarchicalLB::clearObj(ObjSampleType& objs) {
std::vector<int> to_remove{};
for (auto&& bin : objs) {
if (bin.second.size() == 0) {
to_remove.push_back(bin.first);
}
total_size += bin.second.size() * sizeof(ObjIDType);
total_size += sizeof(ObjBinType);
total_size += sizeof(std::size_t) * 4;
}
for (auto&& r : to_remove) {
auto giter = objs.find(r);
vtAssert(giter != objs.end(), "Must exist");
objs.erase(giter);
}
return total_size;
}

void HierarchicalLB::runLB() {
Expand Down
7 changes: 2 additions & 5 deletions src/vt/vrt/collection/balance/hierarchicallb/hierlb.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
#include "vt/vrt/collection/balance/hierarchicallb/hierlb_msgs.h"
#include "vt/vrt/collection/balance/hierarchicallb/hierlb_strat.h"
#include "vt/vrt/collection/balance/baselb/baselb.h"
#include "vt/vrt/collection/balance/lb_invoke/start_lb_msg.h"
#include "vt/timing/timing.h"
#include "vt/objgroup/headers.h"

Expand Down Expand Up @@ -98,8 +97,7 @@ struct HierarchicalLB : BaseLB {
);
void lbTreeUpSend(
NodeType const node, LoadType const child_load, NodeType const child,
ObjSampleType const& load, NodeType const child_size,
std::size_t const& load_size_approx
ObjSampleType const& load, NodeType const child_size
);
void downTree(
NodeType const from, ObjSampleType excess, bool const final_child
Expand All @@ -111,7 +109,7 @@ struct HierarchicalLB : BaseLB {

void sendDownTree();
void distributeAmoungChildren();
std::size_t clearObj(ObjSampleType& objs);
void clearObj(ObjSampleType& objs);
HierLBChild* findMinChild();
void startMigrations();

Expand All @@ -130,7 +128,6 @@ struct HierarchicalLB : BaseLB {
ChildMapType children;
LoadType this_load_begin = 0.0f;
ObjSampleType load_over, given_objs, taken_objs;
std::size_t load_over_size = 0;
int64_t migrates_expected = 0, transfer_count = 0;
TransferType transfers;
objgroup::proxy::Proxy<HierarchicalLB> proxy = {};
Expand Down
10 changes: 3 additions & 7 deletions src/vt/vrt/collection/balance/lb_invoke/invoke_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@

namespace vt { namespace vrt { namespace collection { namespace balance {

template <typename MsgT>
struct InvokeBaseMsg : MsgT {
InvokeBaseMsg() = default;
InvokeBaseMsg(
struct InvokeMsg : collective::ReduceNoneMsg {
InvokeMsg() = default;
InvokeMsg(
PhaseType in_phase, LBType in_lb, bool manual, std::size_t in_num_colls = 1
) : phase_(in_phase), lb_(in_lb), manual_(manual),
num_collections_(in_num_colls)
Expand All @@ -67,9 +66,6 @@ struct InvokeBaseMsg : MsgT {
std::size_t num_collections_ = 0;
};

using InvokeMsg = InvokeBaseMsg<vt::Message>;
using InvokeReduceMsg = InvokeBaseMsg<collective::ReduceNoneMsg>;

}}}} /* end namespace vt::vrt::collection::balance */

#endif /*INCLUDED_VT_VRT_COLLECTION_BALANCE_LB_INVOKE_INVOKE_MSG_H*/
48 changes: 36 additions & 12 deletions src/vt/vrt/collection/balance/lb_invoke/lb_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
#include "vt/configs/arguments/app_config.h"
#include "vt/context/context.h"
#include "vt/vrt/collection/balance/lb_invoke/lb_manager.h"
#include "vt/vrt/collection/balance/lb_invoke/start_lb_msg.h"
#include "vt/vrt/collection/balance/read_lb.h"
#include "vt/vrt/collection/balance/lb_type.h"
#include "vt/vrt/collection/balance/node_stats.h"
Expand Down Expand Up @@ -144,16 +143,22 @@ void LBManager::setLoadModel(std::shared_ptr<LoadModel> model) {
}

template <typename LB>
void
LBManager::makeLB(MsgSharedPtr<StartLBMsg> msg) {
LBManager::LBProxyType
LBManager::makeLB() {
auto proxy = theObjGroup()->makeCollective<LB>();
auto strat = proxy.get();
strat->init(proxy);
auto base_proxy = proxy.template registerBaseCollective<lb::BaseLB>();
auto phase = msg->getPhase();

destroy_lb_ = [proxy]{ proxy.destroyCollective(); };

return base_proxy;
}

void
LBManager::runLB(LBProxyType base_proxy, PhaseType phase) {
lb::BaseLB* strat = base_proxy.get();

runInEpochCollective([=] {
model_->updateLoads(phase);
});
Expand Down Expand Up @@ -209,16 +214,15 @@ void LBManager::collectiveImpl(
);
}

auto msg = makeMessage<StartLBMsg>(phase);
switch (lb) {
case LBType::HierarchicalLB: makeLB<lb::HierarchicalLB>(msg); break;
case LBType::GreedyLB: makeLB<lb::GreedyLB>(msg); break;
case LBType::RotateLB: makeLB<lb::RotateLB>(msg); break;
case LBType::GossipLB: makeLB<lb::GossipLB>(msg); break;
case LBType::StatsMapLB: makeLB<lb::StatsMapLB>(msg); break;
case LBType::RandomLB: makeLB<lb::RandomLB>(msg); break;
case LBType::HierarchicalLB: lb_instances_["chosen"] = makeLB<lb::HierarchicalLB>(); break;
case LBType::GreedyLB: lb_instances_["chosen"] = makeLB<lb::GreedyLB>(); break;
case LBType::RotateLB: lb_instances_["chosen"] = makeLB<lb::RotateLB>(); break;
case LBType::GossipLB: lb_instances_["chosen"] = makeLB<lb::GossipLB>(); break;
case LBType::StatsMapLB: lb_instances_["chosen"] = makeLB<lb::StatsMapLB>(); break;
case LBType::RandomLB: lb_instances_["chosen"] = makeLB<lb::RandomLB>(); break;
# if vt_check_enabled(zoltan)
case LBType::ZoltanLB: makeLB<lb::ZoltanLB>(msg); break;
case LBType::ZoltanLB: lb_instances_["chosen"] = makeLB<lb::ZoltanLB>(); break;
# endif
case LBType::NoLB:
vtAssert(false, "LBType::NoLB is not a valid LB for collectiveImpl");
Expand All @@ -227,6 +231,10 @@ void LBManager::collectiveImpl(
vtAssert(false, "A valid LB must be passed to collectiveImpl");
break;
}

LBProxyType base_proxy = lb_instances_["chosen"];

runLB(base_proxy, phase);
}
}

Expand Down Expand Up @@ -302,6 +310,22 @@ void LBManager::releaseNow(PhaseType phase) {
num_invocations_ = num_release_ = 0;
}

/*
 * Handle an InvokeMsg requesting a load-balancing pass: emit a debug print,
 * report memory usage for the message's phase, flush and then configure
 * tracing for that phase, and finally forward the message's phase, LB type,
 * manual flag, and collection count to collectiveImpl.
 */
void LBManager::sysLB(InvokeMsg* msg) {
  vt_debug_print(lb, node, "sysLB\n");

  // The phase is consulted several times below; read it once.
  auto const phase = msg->phase_;

  printMemoryUsage(phase);
  flushTraceNextPhase();
  setTraceEnabledNextPhase(phase);

  collectiveImpl(phase, msg->lb_, msg->manual_, msg->num_collections_);
}

/*
 * Handle an InvokeMsg for the release path: emit a debug print, report memory
 * usage for the message's phase, flush and then configure tracing for that
 * phase, and finally forward the phase and collection count to releaseImpl.
 */
void LBManager::sysReleaseLB(InvokeMsg* msg) {
  vt_debug_print(lb, node, "sysReleaseLB\n");

  // The phase is consulted several times below; read it once.
  auto const phase = msg->phase_;

  printMemoryUsage(phase);
  flushTraceNextPhase();
  setTraceEnabledNextPhase(phase);

  releaseImpl(phase, msg->num_collections_);
}

void LBManager::setTraceEnabledNextPhase(PhaseType phase) {
// Set if tracing is enabled for this next phase. Do this immediately before
// LB runs so LB is always instrumented as the beginning of the next phase
Expand Down
29 changes: 10 additions & 19 deletions src/vt/vrt/collection/balance/lb_invoke/lb_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@
#include "vt/config.h"
#include "vt/vrt/collection/balance/lb_type.h"
#include "vt/vrt/collection/balance/lb_invoke/invoke_msg.h"
#include "vt/vrt/collection/balance/lb_invoke/start_lb_msg.h"
#include "vt/configs/arguments/args.h"
#include "vt/runtime/component/component_pack.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"
#include "vt/vrt/collection/balance/baselb/baselb.h"

#include <functional>

Expand All @@ -69,6 +70,7 @@ class LoadModel;
*/
struct LBManager : runtime::component::Component<LBManager> {
using ListenerFnType = std::function<void(PhaseType)>;
using LBProxyType = objgroup::proxy::Proxy<lb::BaseLB>;

/**
* \internal \brief System call to construct a \c LBManager
Expand Down Expand Up @@ -184,29 +186,15 @@ struct LBManager : runtime::component::Component<LBManager> {
*
* \param[in] msg the LB message
*/
template <typename MsgT>
void sysLB(MsgT* msg) {
vt_debug_print(lb, node, "sysLB\n");
printMemoryUsage(msg->phase_);
flushTraceNextPhase();
setTraceEnabledNextPhase(msg->phase_);
return collectiveImpl(msg->phase_, msg->lb_, msg->manual_, msg->num_collections_);
}
void sysLB(InvokeMsg* msg);

/**
* \internal \brief Tell the manager that a collection has hit \c nextPhase,
* choosing to skip load balancing
*
* \param[in] msg the LB message
*/
template <typename MsgT>
void sysReleaseLB(MsgT* msg) {
vt_debug_print(lb, node, "sysReleaseLB\n");
printMemoryUsage(msg->phase_);
flushTraceNextPhase();
setTraceEnabledNextPhase(msg->phase_);
return releaseImpl(msg->phase_, msg->num_collections_);
}
void sysReleaseLB(InvokeMsg* msg);

public:
/**
Expand Down Expand Up @@ -256,12 +244,14 @@ struct LBManager : runtime::component::Component<LBManager> {
/**
* \internal \brief Collectively construct a new load balancer
*
* \param[in] msg the start LB message
* \tparam LB the type of strategy to instantiate
*
* \return objgroup proxy to the new load balancer
*/
template <typename LB>
void makeLB(MsgSharedPtr<StartLBMsg> msg);
LBProxyType makeLB();

void runLB(LBProxyType base_proxy, PhaseType phase);

private:
std::size_t num_invocations_ = 0;
Expand All @@ -274,6 +264,7 @@ struct LBManager : runtime::component::Component<LBManager> {
objgroup::proxy::Proxy<LBManager> proxy_;
std::shared_ptr<LoadModel> base_model_;
std::shared_ptr<LoadModel> model_;
std::unordered_map<std::string, LBProxyType> lb_instances_;
};

}}}} /* end namespace vt::vrt::collection::balance */
Expand Down
Loading