Skip to content

Commit

Permalink
#2074: Disallow operating on non-consecutive phases
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable authored and cwschilly committed Sep 20, 2024
1 parent a08b77e commit 0623eb1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 45 deletions.
50 changes: 24 additions & 26 deletions src/vt/vrt/collection/balance/lb_data_restart_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,34 +162,32 @@ void LBDataRestartReader::determinePhasesToMigrate() {

auto const this_node = theContext()->getNode();
runInEpochCollective("LBDataRestartReader::updateLocations", [&]{
PhaseType curr = 0, next;
for (;curr < num_phases_ - 1;) {
next = findNextPhase(curr);

local_changed_distro[curr] = *history_[curr] != *history_[next];
if (local_changed_distro[curr]) {
std::set<ElementIDStruct> departing, arriving;

std::set_difference(
history_[next]->begin(), history_[next]->end(),
history_[curr]->begin(), history_[curr]->end(),
std::inserter(arriving, arriving.begin())
);

std::set_difference(
history_[curr]->begin(), history_[curr]->end(),
history_[next]->begin(), history_[next]->end(),
std::inserter(departing, departing.begin())
);

for (auto&& d : departing) {
proxy_[d.getHomeNode()].send<DepartMsg, &LBDataRestartReader::departing>(this_node, next, d);
}
for (auto&& a : arriving) {
proxy_[a.getHomeNode()].send<ArriveMsg, &LBDataRestartReader::arriving>(this_node, next, a);
for (PhaseType curr = 0; curr < num_phases_ - 1; ++curr) {
if(history_.count(curr) && history_.count(curr + 1)) {
local_changed_distro[curr] = *history_[curr] != *history_[curr + 1];
if (local_changed_distro[curr]) {
std::set<ElementIDStruct> departing, arriving;

std::set_difference(
history_[curr + 1]->begin(), history_[curr + 1]->end(),
history_[curr]->begin(), history_[curr]->end(),
std::inserter(arriving, arriving.begin())
);

std::set_difference(
history_[curr]->begin(), history_[curr]->end(),
history_[curr + 1]->begin(), history_[curr + 1]->end(),
std::inserter(departing, departing.begin())
);

for (auto&& d : departing) {
proxy_[d.getHomeNode()].send<DepartMsg, &LBDataRestartReader::departing>(this_node, curr + 1, d);
}
for (auto&& a : arriving) {
proxy_[a.getHomeNode()].send<ArriveMsg, &LBDataRestartReader::arriving>(this_node, curr + 1, a);
}
}
}
curr = next;
}
});

Expand Down
15 changes: 1 addition & 14 deletions src/vt/vrt/collection/balance/lb_data_restart_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct LBDataRestartReader : runtime::component::Component<LBDataRestartReader>
*
* \param[in] phase the phase
*
* \return pointer to elements assigned to this node, guaranted to be not null
* \return pointer to elements assigned to this node, guaranteed to be not null
*/
std::shared_ptr<const std::set<ElementIDStruct>> getDistro(PhaseType phase) const {
auto iter = history_.find(phase);
Expand Down Expand Up @@ -173,19 +173,6 @@ struct LBDataRestartReader : runtime::component::Component<LBDataRestartReader>
}

private:
/**
* \brief Find the next specified phase or an identical one
*
* \param phase the current phase
*
* \return the next phase
*/
PhaseType findNextPhase(PhaseType phase) const {
auto iter = history_.upper_bound(phase);
vtAssert(iter != history_.end(), "Must have a valid phase");
return iter->first;
}

/**
* \brief Reduce distribution changes globally to find where migrations need
* to occur
Expand Down
6 changes: 1 addition & 5 deletions tests/unit/lb/test_offlinelb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,9 @@ struct SimCol : vt::Collection<SimCol, vt::Index1D> {

void sparseHandler(Msg* m){
auto const this_node = theContext()->getNode();
auto const num_nodes = theContext()->getNumNodes();
auto const next_node = (this_node + 1) % num_nodes;
vt_debug_print(terse, lb, "sparseHandler: idx={}: elm={}\n", getIndex(), getElmID());
if (m->iter == 0 or m->iter == 1 or m->iter == 2 or m->iter == 3 or m->iter == 4) {
if (m->iter == 0 or m->iter == 1 or m->iter == 2 or m->iter == 3 or m->iter == 4 or m->iter == 5 or m->iter == 6) {
EXPECT_EQ(getIndex().x() / 2, this_node);
} else if (m->iter == 5 or m->iter == 6) {
EXPECT_EQ(getIndex().x() / 2, next_node);
}
}
};
Expand Down

0 comments on commit 0623eb1

Please sign in to comment.