
Merge pull request #824 from DARMA-tasking/649-action-epoch
Refactor code away from low-level termination primitives to higher-level routines
lifflander authored Jun 10, 2020
2 parents 84bb174 + fde87f3 commit 26a049d
Showing 15 changed files with 79 additions and 92 deletions.
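
The thrust of the change: per-example executeInEpoch helpers and hand-rolled addAction/runScheduler loops are replaced by new library routines (vt::runInEpochRooted, vt::runInEpochCollective, and vt::runSchedulerThrough) added in src/vt/scheduler/scheduler.cc. A minimal before/after sketch, condensed from the diffs below; doWork() is a placeholder for whatever the call site enqueues, such as a collection broadcast:

// Before: each call site rolled its own epoch and scheduler boilerplate.
auto ep = vt::theTerm()->makeEpochCollective();
vt::theMsg()->pushEpoch(ep);
doWork();
vt::theMsg()->popEpoch(ep);
vt::theTerm()->finishedEpoch(ep);
bool done = false;
vt::theTerm()->addAction(ep, [&done]{ done = true; });
do vt::runScheduler(); while (!done);

// After: a single higher-level call expresses the same pattern.
vt::runInEpochCollective([=]{ doWork(); });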
25 changes: 6 additions & 19 deletions examples/collection/lb_iter.cc
@@ -104,21 +104,6 @@ void IterCol::iterWork(IterMsg* msg) {
}
}

template <typename Callable>
void executeInEpoch(Callable&& fn) {
auto this_node = vt::theContext()->getNode();
auto ep = vt::theTerm()->makeEpochCollective();
vt::theMsg()->pushEpoch(ep);
if (this_node == 0) {
fn();
}
vt::theMsg()->popEpoch(ep);
vt::theTerm()->finishedEpoch(ep);
bool done = false;
vt::theTerm()->addAction(ep, [&done]{ done = true; });
do vt::runScheduler(); while (!done);
}

int main(int argc, char** argv) {
vt::initialize(argc, argv);

@@ -153,17 +138,19 @@ int main(int argc, char** argv) {
for (int i = 0; i < num_iter; i++) {
auto cur_time = vt::timing::Timing::getCurrentTime();

executeInEpoch([=]{
proxy.broadcast<IterCol::IterMsg,&IterCol::iterWork>(10, i);
vt::runInEpochCollective([=]{
if (this_node == 0)
proxy.broadcast<IterCol::IterMsg,&IterCol::iterWork>(10, i);
});

auto total_time = vt::timing::Timing::getCurrentTime() - cur_time;
if (this_node == 0) {
fmt::print("iteration: iter={},time={}\n", i, total_time);
}

executeInEpoch([=]{
proxy.broadcast<IterCol::EmptyMsg,&IterCol::runLB>();
vt::runInEpochCollective([=]{
if (this_node == 0)
proxy.broadcast<IterCol::EmptyMsg,&IterCol::runLB>();
});

}
18 changes: 3 additions & 15 deletions examples/collection/migrate_collection.cc
@@ -86,18 +86,6 @@ static void migrateToNext(ColMsg* msg, Hello* col) {
col->migrate(next_node);
}

template <typename Callable>
void executeInEpoch(Callable&& fn) {
auto ep = vt::theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(ep);
fn();
vt::theMsg()->popEpoch(ep);
vt::theTerm()->finishedEpoch(ep);
bool done = false;
vt::theTerm()->addAction(ep, [&done]{ done = true; });
do vt::runScheduler(); while (!done);
}

int main(int argc, char** argv) {
vt::initialize(argc, argv);

@@ -117,17 +105,17 @@ int main(int argc, char** argv) {
auto range = vt::Index1D(num_elms);
auto proxy = vt::theCollection()->construct<Hello>(range, this_node);

executeInEpoch([=]{
vt::runInEpochRooted([=]{
auto msg = vt::makeMessage<ColMsg>(this_node);
proxy.broadcast<ColMsg, doWork>(msg.get());
});

executeInEpoch([=]{
vt::runInEpochRooted([=]{
auto msg = vt::makeMessage<ColMsg>(this_node);
proxy.broadcast<ColMsg, migrateToNext>(msg.get());
});

executeInEpoch([=]{
vt::runInEpochRooted([=]{
auto msg = vt::makeMessage<ColMsg>(this_node);
proxy.broadcast<ColMsg, doWork>(msg.get());
});
35 changes: 18 additions & 17 deletions examples/collection/reduce_integral.cc
@@ -221,27 +221,28 @@ int main(int argc, char** argv) {
}
}

if (this_node == 0) {
//
// Create the interval decomposition into objects
//
using BaseIndexType = typename vt::Index1D::DenseIndexType;
auto range = vt::Index1D(static_cast<BaseIndexType>(num_objs));

auto proxy = vt::theCollection()->construct<Integration1D>(range);
proxy.broadcast<Integration1D::InitMsg,&Integration1D::compute>(
num_objs, numIntPerObject
);
}
vt::runInEpochCollective([=]{
if (this_node == 0) {
//
// Create the interval decomposition into objects
//
using BaseIndexType = typename vt::Index1D::DenseIndexType;
auto range = vt::Index1D(static_cast<BaseIndexType>(num_objs));

auto proxy = vt::theCollection()->construct<Integration1D>(range);
proxy.broadcast<Integration1D::InitMsg,&Integration1D::compute>(
  num_objs, numIntPerObject
);
}
});

// Add something like this to validate the reduction.
// Create the variable root_reduce_finished as a static variable,
// which is only checked on one node.
vt::theTerm()->addAction([]{
if (vt::theContext()->getNode() == reduce_root_node) {
vtAssertExpr(root_reduce_finished == true);
}
});
if (vt::theContext()->getNode() == reduce_root_node) {
vtAssertExpr(root_reduce_finished == true);
}

vt::finalize();

4 changes: 1 addition & 3 deletions src/vt/rdmahandle/sub_handle.impl.h
@@ -594,9 +594,7 @@ void SubHandle<T,E,IndexT>::afterLB() {
theMsg()->popEpoch(epoch);
theTerm()->finishedEpoch(epoch);

bool done = false;
theTerm()->addAction(epoch, [&done]{ done = true; });
theSched()->runSchedulerWhile([&done]{ return not done; });
runSchedulerThrough(epoch);
}

template <typename T, HandleEnum E, typename IndexT>
28 changes: 28 additions & 0 deletions src/vt/scheduler/scheduler.cc
@@ -330,4 +330,32 @@ void runScheduler() {
theSched()->scheduler();
}

void runSchedulerThrough(EpochType epoch) {
// WARNING: The produce()/consume() pair below prevents global termination
// from spuriously concluding that the work done in this loop over the
// scheduler represents the entire work of the program, which would tear
// the runtime down prematurely.
theTerm()->produce();
theSched()->runSchedulerWhile([=]{ return !theTerm()->isEpochTerminated(epoch); });
theTerm()->consume();
}

void runInEpochRooted(ActionType&& fn) {
auto ep = theTerm()->makeEpochRooted();
theMsg()->pushEpoch(ep);
fn();
theMsg()->popEpoch(ep);
theTerm()->finishedEpoch(ep);
runSchedulerThrough(ep);
}

void runInEpochCollective(ActionType&& fn) {
auto ep = theTerm()->makeEpochCollective();
theMsg()->pushEpoch(ep);
fn();
theMsg()->popEpoch(ep);
theTerm()->finishedEpoch(ep);
runSchedulerThrough(ep);
}

} //end namespace vt
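
Call sites that manage an epoch themselves, as SubHandle<T,E,IndexT>::afterLB() does above, can invoke runSchedulerThrough directly once the epoch is finished. A sketch of that usage, with doWork() again a hypothetical placeholder:

auto ep = vt::theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(ep);
doWork();  // messages sent here are enqueued under ep
vt::theMsg()->popEpoch(ep);
vt::theTerm()->finishedEpoch(ep);
// Spin the scheduler until ep terminates; the produce()/consume() pair
// inside runSchedulerThrough keeps global termination from firing early.
vt::runSchedulerThrough(ep);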
4 changes: 4 additions & 0 deletions src/vt/scheduler/scheduler.h
@@ -178,6 +178,10 @@ struct Scheduler : runtime::component::Component<Scheduler> {
namespace vt {

void runScheduler();
void runSchedulerThrough(EpochType epoch);

void runInEpochRooted(ActionType&& fn);
void runInEpochCollective(ActionType&& fn);

extern sched::Scheduler* theSched();

4 changes: 3 additions & 1 deletion src/vt/termination/termination.cc
@@ -811,7 +811,9 @@ TermStatusEnum TerminationDetector::testEpochTerminated(EpochType epoch) {
TermStatusEnum status = TermStatusEnum::Pending;
auto const& is_rooted_epoch = epoch::EpochManip::isRooted(epoch);

if (is_rooted_epoch) {
if (getWindow(epoch)->isTerminated(epoch)) {
status = TermStatusEnum::Terminated;
} else if (is_rooted_epoch) {
auto const& this_node = theContext()->getNode();
auto const& root = epoch::EpochManip::node(epoch);
if (root == this_node) {
3 changes: 2 additions & 1 deletion src/vt/termination/termination.h
@@ -269,7 +269,8 @@ struct TerminationDetector :
public:
// TermTerminated interface
TermStatusEnum testEpochTerminated(EpochType epoch) override;
// Might return (conservatively) false if the epoch is non-local
// Might return (conservatively) false for some time if the epoch is
// non-local, but will eventually return true
bool isEpochTerminated(EpochType epoch);

public:
3 changes: 3 additions & 0 deletions src/vt/vrt/collection/balance/baselb/baselb.cc
@@ -194,6 +194,7 @@ EpochType BaseLB::startMigrationCollective() {
migration_epoch_ = theTerm()->makeEpochCollective("LB migration");
theTerm()->addAction(migration_epoch_, [this]{ this->migrationDone(); });
theMsg()->pushEpoch(migration_epoch_);
during_migration_ = true;
return migration_epoch_;
}

@@ -210,6 +211,7 @@ void BaseLB::finishMigrationCollective() {

theMsg()->popEpoch(migration_epoch_);
theTerm()->finishedEpoch(migration_epoch_);
during_migration_ = false;
}

void BaseLB::transferSend(
@@ -232,6 +234,7 @@ void BaseLB::transferMigrations(TransferMsg<TransferVecType>* msg) {
}

void BaseLB::migrateObjectTo(ObjIDType const obj_id, NodeType const to) {
vtAssert(
  during_migration_,
  "migrateObjectTo must be called between startMigrationCollective and"
  " finishMigrationCollective"
);
auto from = objGetNode(obj_id);
if (from != to) {
bool has_object = theProcStats()->hasObjectToMigrate(obj_id);
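
The new during_migration_ flag makes the required call ordering explicit. A sketch under the assumption that lb is any BaseLB-derived load balancer and obj_id/to_node name a migratable object and its destination:

lb->startMigrationCollective();        // opens migration_epoch_, sets during_migration_
lb->migrateObjectTo(obj_id, to_node);  // asserts during_migration_ is set
lb->finishMigrationCollective();       // closes the epoch, clears during_migration_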
1 change: 1 addition & 0 deletions src/vt/vrt/collection/balance/baselb/baselb.h
@@ -131,6 +131,7 @@ struct BaseLB {
ElementLoadType const* load_data = nullptr;
ElementCommType const* comm_data = nullptr;
StatisticMapType stats = {};
bool during_migration_ = false;
EpochType migration_epoch_ = no_epoch;
TransferType off_node_migrate_ = {};
objgroup::proxy::Proxy<BaseLB> proxy_ = {};
8 changes: 2 additions & 6 deletions src/vt/vrt/collection/balance/gossiplb/gossiplb.cc
@@ -188,9 +188,7 @@ void GossipLB::inform() {

theSched()->runSchedulerWhile([this]{ return not setup_done_; });

bool inform_done = false;
auto propagate_epoch = theTerm()->makeEpochCollective("GossipLB: inform");
theTerm()->addAction(propagate_epoch, [&inform_done] { inform_done = true; });

// Underloaded start the round
if (is_underloaded_) {
@@ -199,7 +197,7 @@

theTerm()->finishedEpoch(propagate_epoch);

theSched()->runSchedulerWhile([&inform_done]{ return not inform_done; });
vt::runSchedulerThrough(propagate_epoch);

debug_print(
gossiplb, node,
@@ -378,9 +376,7 @@ GossipLB::selectObject(
void GossipLB::decide() {
double const avg = stats.at(lb::Statistic::P_l).at(lb::StatisticQuantity::avg);

bool decide_done = false;
auto lazy_epoch = theTerm()->makeEpochCollective("GossipLB: decide");
theTerm()->addAction(lazy_epoch, [&decide_done] { decide_done = true; });

if (is_overloaded_) {
std::vector<NodeType> under = makeUnderloaded();
@@ -466,7 +462,7 @@

theTerm()->finishedEpoch(lazy_epoch);

theSched()->runSchedulerWhile([&decide_done]{ return not decide_done; });
vt::runSchedulerThrough(lazy_epoch);
}

void GossipLB::thunkMigrations() {
4 changes: 1 addition & 3 deletions src/vt/vrt/collection/balance/greedylb/greedylb.cc
@@ -226,8 +226,7 @@ void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) {
"recvObjsDirect: num_recs={}\n", num_recs
);

auto epoch = startMigrationCollective();
theMsg()->pushEpoch(epoch);
startMigrationCollective();

for (decltype(+num_recs) i = 0; i < num_recs; i++) {
auto const to_node = objGetNode(recs[i]);
@@ -243,7 +242,6 @@ void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) {
migrateObjectTo(new_obj_id, to_node);
}

theMsg()->popEpoch(epoch);
finishMigrationCollective();
}

14 changes: 2 additions & 12 deletions tests/unit/termination/test_term_chaining.cc
@@ -115,16 +115,6 @@ struct TestTermChaining : TestParallelHarness {

chain.done();
}

static void run_to_term() {
bool finished = false;

theTerm()->addAction(epoch, [&finished]{ finished = true; });

while (!finished) {
runScheduler();
}
}
};

/*static*/ int32_t TestTermChaining::handler_count = 0;
@@ -148,15 +138,15 @@ TEST_F(TestTermChaining, test_termination_chaining_1) {
theTerm()->finishedEpoch(epoch);
theMsg()->popEpoch(epoch);
fmt::print("before run 1\n");
run_to_term();
vt::runSchedulerThrough(epoch);
fmt::print("after run 1\n");

EXPECT_EQ(handler_count, 4);
} else {
theMsg()->pushEpoch(epoch);
theTerm()->finishedEpoch(epoch);
theMsg()->popEpoch(epoch);
run_to_term();
vt::runSchedulerThrough(epoch);
EXPECT_EQ(handler_count, 13);
}
}
14 changes: 4 additions & 10 deletions tests/unit/termination/test_term_cleanup.cc
@@ -76,7 +76,6 @@ TEST_F(TestTermCleanup, test_termination_cleanup_1) {

for (int i = 0; i < num_epochs; i++) {
EpochType const epoch = theTerm()->makeEpochCollective();
bool done = false;
//fmt::print("global collective epoch {:x}\n", epoch);

NodeType const next = this_node + 1 < num_nodes ? this_node + 1 : 0;
@@ -86,8 +85,7 @@
theMsg()->sendMsg<TestMsgType, handler>(next, msg.get());

theTerm()->finishedEpoch(epoch);
theTerm()->addAction(epoch, [&]{ done = true; });
do vt::runScheduler(); while (not done);
vt::runSchedulerThrough(epoch);

EXPECT_LT(theTerm()->getEpochState().size(), std::size_t{2});
EXPECT_EQ(theTerm()->getEpochWaitSet().size(), std::size_t{0});
@@ -123,9 +121,6 @@ TEST_F(TestTermCleanup, test_termination_cleanup_2) {
EpochType const wave_epoch = theTerm()->makeEpochRootedWave(
term::SuccessorEpochCapture{no_epoch}
);
bool coll_done = false;
bool root_done = false;
bool wave_done = false;
//fmt::print("global collective epoch {:x}\n", epoch);

NodeType const next = this_node + 1 < num_nodes ? this_node + 1 : 0;
@@ -149,10 +144,9 @@
theTerm()->finishedEpoch(coll_epoch);
theTerm()->finishedEpoch(root_epoch);
theTerm()->finishedEpoch(wave_epoch);
theTerm()->addAction(coll_epoch, [&]{ coll_done = true; });
theTerm()->addAction(root_epoch, [&]{ root_done = true; });
theTerm()->addAction(wave_epoch, [&]{ wave_done = true; });
do vt::runScheduler(); while (not coll_done or not root_done or not wave_done);
vt::runSchedulerThrough(coll_epoch);
vt::runSchedulerThrough(root_epoch);
vt::runSchedulerThrough(wave_epoch);
}

while (not vt::rt->isTerminated() or not vt::theSched()->isIdle()) {
Expand Down
6 changes: 1 addition & 5 deletions tests/unit/termination/test_term_dep_send_chain.cc
@@ -414,15 +414,11 @@ struct MyObjGroup {
}

void finishUpdate() {
bool vt_working = true;
chains_->phaseDone();
vt::theMsg()->popEpoch(epoch_);
vt::theTerm()->addAction(epoch_, [&vt_working] { vt_working = false; });
vt::theTerm()->finishedEpoch(epoch_);

while (vt_working) {
vt::runScheduler();
}
vt::runSchedulerThrough(epoch_);

started_ = false;
}
