Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#941 added subphase support to loadComm #1081

Merged
merged 15 commits into from
Oct 15, 2020
Merged
32 changes: 26 additions & 6 deletions src/vt/vrt/collection/balance/elm_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,31 +79,38 @@ void ElementStats::stopTime() {
);
}

void ElementStats::recvComm(
LBCommKey key, double bytes
) {
comm_.resize(cur_phase_ + 1);
comm_.at(cur_phase_)[key].receiveMsg(bytes);
subphase_comm_.resize(cur_phase_ + 1);
subphase_comm_.at(cur_phase_).resize(cur_subphase_ + 1);
subphase_comm_.at(cur_phase_).at(cur_subphase_)[key].receiveMsg(bytes);
}

void ElementStats::recvObjData(
ElementIDType pto, ElementIDType tto,
ElementIDType pfrom, ElementIDType tfrom, double bytes, bool bcast
) {
comm_.resize(cur_phase_ + 1);
LBCommKey key(LBCommKey::CollectionTag{}, pfrom, tfrom, pto, tto, bcast);
comm_.at(cur_phase_)[key].receiveMsg(bytes);
recvComm(key, bytes);
}

void ElementStats::recvFromNode(
ElementIDType pto, ElementIDType tto, NodeType from,
double bytes, bool bcast
) {
comm_.resize(cur_phase_ + 1);
LBCommKey key(LBCommKey::NodeToCollectionTag{}, from, pto, tto, bcast);
comm_.at(cur_phase_)[key].receiveMsg(bytes);
recvComm(key, bytes);
}

void ElementStats::recvToNode(
NodeType to, ElementIDType pfrom, ElementIDType tfrom,
double bytes, bool bcast
) {
comm_.resize(cur_phase_ + 1);
LBCommKey key(LBCommKey::CollectionToNodeTag{}, pfrom, tfrom, to, bcast);
comm_.at(cur_phase_)[key].receiveMsg(bytes);
recvComm(key, bytes);
}

void ElementStats::setModelWeight(TimeType const& time) {
Expand Down Expand Up @@ -195,6 +202,19 @@ ElementStats::getComm(PhaseType const& phase) {
return phase_comm;
}

std::vector<CommMapType> const& ElementStats::getSubphaseComm(PhaseType phase) {
subphase_comm_.resize(phase + 1);
auto const& subphase_comm = subphase_comm_[phase];

vt_debug_print(
lb, node,
"ElementStats: getSubphaseComm: comm size={}, phase={}\n",
subphase_comm.size(), phase
);

return subphase_comm;
}

void ElementStats::setSubPhase(SubphaseType subphase) {
vtAssert(subphase < no_subphase, "subphase must be less than sentinel");
cur_subphase_ = subphase;
Expand Down
3 changes: 3 additions & 0 deletions src/vt/vrt/collection/balance/elm_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct ElementStats {
void startTime();
void stopTime();
void addTime(TimeType const& time);
void recvComm(LBCommKey key, double bytes);
void recvObjData(
ElementIDType to_perm, ElementIDType to_temp,
ElementIDType from_perm, ElementIDType from_temp, double bytes, bool bcast
Expand All @@ -88,6 +89,7 @@ struct ElementStats {
TimeType getLoad(PhaseType phase, SubphaseType subphase) const;

CommMapType const& getComm(PhaseType const& phase);
std::vector<CommMapType> const& getSubphaseComm(PhaseType phase);
void setSubPhase(SubphaseType subphase);
SubphaseType getSubPhase() const;

Expand All @@ -113,6 +115,7 @@ struct ElementStats {

SubphaseType cur_subphase_ = 0;
std::vector<std::vector<TimeType>> subphase_timings_ = {};
std::vector<std::vector<CommMapType>> subphase_comm_ = {};

static std::unordered_map<VirtualProxyType, SubphaseType> focused_subphase_;
};
Expand Down
3 changes: 2 additions & 1 deletion src/vt/vrt/collection/balance/elm_stats.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ template <typename ColT>
auto const& total_load = stats.getLoad(cur_phase, getFocusedSubPhase(untyped_proxy));
auto const& subphase_loads = stats.subphase_timings_.at(cur_phase);
auto const& comm = stats.getComm(cur_phase);
auto const& subphase_comm = stats.getSubphaseComm(cur_phase);
auto const& idx = col->getIndex();
auto const& elm_proxy = proxy[idx];

theNodeStats()->addNodeStats(col, cur_phase, total_load, subphase_loads, comm);
theNodeStats()->addNodeStats(col, cur_phase, total_load, subphase_loads, comm, subphase_comm);

auto const before_ready = theCollection()->numReadyCollections();
theCollection()->makeCollectionReady(untyped_proxy);
Expand Down
14 changes: 13 additions & 1 deletion src/vt/vrt/collection/balance/node_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ std::unordered_map<PhaseType, CommMapType> const* NodeStats::getNodeComm() const
return &node_comm_;
}

std::unordered_map<PhaseType, std::unordered_map<SubphaseType, CommMapType>> const* NodeStats::getNodeSubphaseComm() const {
return &node_subphase_comm_;
}

void NodeStats::clearStats() {
NodeStats::node_comm_.clear();
NodeStats::node_data_.clear();
Expand Down Expand Up @@ -301,7 +305,8 @@ void NodeStats::outputStatsForPhase(PhaseType phase) {
ElementIDType NodeStats::addNodeStats(
Migratable* col_elm,
PhaseType const& phase, TimeType const& time,
std::vector<TimeType> const& subphase_time, CommMapType const& comm
std::vector<TimeType> const& subphase_time,
CommMapType const& comm, std::vector<CommMapType> const& subphase_comm
) {
// A new temp ID gets assigned when a object is migrated into a node

Expand Down Expand Up @@ -337,6 +342,13 @@ ElementIDType NodeStats::addNodeStats(
comm_data[c.first] += c.second;
}

auto &subphase_comm_data = node_subphase_comm_[phase];
for (SubphaseType i = 0; i < subphase_comm.size(); i++) {
for (auto& sp : subphase_comm[i]) {
subphase_comm_data[i][sp.first] += sp.second;
}
}

node_temp_to_perm_[temp_id] = perm_id;
node_perm_to_temp_[perm_id] = temp_id;

Expand Down
12 changes: 11 additions & 1 deletion src/vt/vrt/collection/balance/node_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ struct NodeStats : runtime::component::Component<NodeStats> {
ElementIDType addNodeStats(
Migratable* col_elm,
PhaseType const& phase, TimeType const& time,
std::vector<TimeType> const& subphase_time, CommMapType const& comm
std::vector<TimeType> const& subphase_time,
CommMapType const& comm, std::vector<CommMapType> const& subphase_comm
);

/**
Expand Down Expand Up @@ -189,6 +190,13 @@ struct NodeStats : runtime::component::Component<NodeStats> {
*/
std::unordered_map<PhaseType, CommMapType> const* getNodeComm() const;

/**
* \internal \brief Get stored object comm subphase graph
*
* \return an observer pointer to the comm subphase graph
*/
std::unordered_map<PhaseType, std::unordered_map<SubphaseType, CommMapType>> const* getNodeSubphaseComm() const;

/**
* \internal \brief Test if this node has an object to migrate
*
Expand Down Expand Up @@ -268,6 +276,8 @@ struct NodeStats : runtime::component::Component<NodeStats> {
std::unordered_map<ElementIDType,VirtualProxyType> node_collection_lookup_;
/// Node communication graph for each local object
std::unordered_map<PhaseType, CommMapType> node_comm_;
/// Node communication graph for each subphase
std::unordered_map<PhaseType, std::unordered_map<SubphaseType, CommMapType>> node_subphase_comm_;
/// The current element ID
ElementIDType next_elm_;
/// The stats file name for outputting instrumentation
Expand Down