Skip to content

Commit

Permalink
#702: tests: add chaining test that includes reduce to test_term_chai…
Browse files Browse the repository at this point in the history
…ning.cc
  • Loading branch information
nmm0 committed Mar 2, 2020
1 parent 5da6a26 commit aa057bf
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions tests/unit/termination/test_term_chaining.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ struct TestTermChaining : TestParallelHarness {
static vt::messaging::DependentSendChain chain;
static vt::EpochType epoch;

struct ChainReduceMsg : collective::ReduceMsg {
ChainReduceMsg(int in_num)
: num(in_num)
{}

int num = 0;
};

static void test_handler_reflector(TestMsg* msg) {
fmt::print("reflector run\n");

Expand Down Expand Up @@ -98,6 +106,29 @@ struct TestTermChaining : TestParallelHarness {
handler_count = 4;
}

static void test_handler_set(TestMsg* msg) {
handler_count = 1;
}

static void test_handler_reduce(ChainReduceMsg *msg) {
if (msg->isRoot()) {
EXPECT_EQ(theContext()->getNode(), 0);
EXPECT_EQ(handler_count, 1);
auto n = theContext()->getNumNodes();
EXPECT_EQ(msg->num, n * (n - 1)/2);
handler_count = 2;
}
}

static void test_handler_bcast(TestMsg* msg) {
if (theContext()->getNode() == 0) {
EXPECT_EQ(handler_count, 2);
} else {
EXPECT_EQ(handler_count, 12);
}
handler_count = 3;
}

static void start_chain() {
EpochType epoch1 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch1);
Expand All @@ -116,6 +147,58 @@ struct TestTermChaining : TestParallelHarness {
chain.done();
}

static void chain_reduce() {
auto node = theContext()->getNode();

if (0 == node) {
EpochType epoch1 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch1);
auto msg = makeSharedMessage<TestMsg>();
chain.add(
epoch1, theMsg()->sendMsg<TestMsg, test_handler_reflector>(1, msg));
vt::theMsg()->popEpoch(epoch1);
vt::theTerm()->finishedEpoch(epoch1);
}

EpochType epoch2 = theTerm()->makeEpochCollective();
vt::theMsg()->pushEpoch(epoch2);
auto msg2 = makeSharedMessage<ChainReduceMsg>(theContext()->getNode());
chain.add(epoch2, theCollective()->reduce<ChainReduceMsg, test_handler_reduce>(0, msg2));
vt::theMsg()->popEpoch(epoch2);
vt::theTerm()->finishedEpoch(epoch2);

// Broadcast from both nodes, bcast wont send to itself
EpochType epoch3 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch3);
auto msg3 = makeSharedMessage<TestMsg>();
chain.add(
epoch3, theMsg()->broadcastMsg<TestMsg, test_handler_bcast>(msg3));
vt::theMsg()->popEpoch(epoch3);
vt::theTerm()->finishedEpoch(epoch3);

chain.done();
}

static void chain_reduce_single() {
handler_count = 1;

EpochType epoch2 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch2);
auto msg2 = makeSharedMessage<ChainReduceMsg>(theContext()->getNode());
chain.add(epoch2, theCollective()->reduce<ChainReduceMsg, test_handler_reduce>(0, msg2));
vt::theMsg()->popEpoch(epoch2);
vt::theTerm()->finishedEpoch(epoch2);

EpochType epoch3 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch3);
auto msg3 = makeSharedMessage<TestMsg>();
chain.add(epoch3, theMsg()->broadcastMsg<TestMsg, test_handler_bcast>(msg3));
vt::theMsg()->popEpoch(epoch3);
vt::theTerm()->finishedEpoch(epoch3);

chain.done();
}

static void run_to_term() {
bool finished = false;

Expand All @@ -140,6 +223,8 @@ TEST_F(TestTermChaining, test_termination_chaining_1) {

epoch = theTerm()->makeEpochCollective();

handler_count = 0;

fmt::print("global collective epoch {:x}\n", epoch);

if (this_node == 0) {
Expand All @@ -161,4 +246,35 @@ TEST_F(TestTermChaining, test_termination_chaining_1) {
}
}

TEST_F(TestTermChaining, test_termination_chaining_collective_1) {
auto const& num_nodes = theContext()->getNumNodes();

if (num_nodes == 2) {

chain = vt::messaging::DependentSendChain{};
epoch = theTerm()->makeEpochCollective();

handler_count = 0;

theMsg()->pushEpoch(epoch);
chain_reduce();
theTerm()->finishedEpoch(epoch);
theMsg()->popEpoch(epoch);
run_to_term();
EXPECT_EQ(handler_count, 3);
} else if (num_nodes == 1) {
chain = vt::messaging::DependentSendChain{};
epoch = theTerm()->makeEpochCollective();

handler_count = 0;

theMsg()->pushEpoch(epoch);
chain_reduce_single();
theTerm()->finishedEpoch(epoch);
theMsg()->popEpoch(epoch);
run_to_term();
// EXPECT_EQ(handler_count, 3);
}
}

}}} // end namespace vt::tests::unit

0 comments on commit aa057bf

Please sign in to comment.