From aa057bf8599e23d9b260aa49ab04a4d44245113d Mon Sep 17 00:00:00 2001 From: Nicolas Morales Date: Mon, 2 Mar 2020 12:12:06 -0800 Subject: [PATCH] #702: tests: add chaining test that includes reduce to test_term_chaining.cc --- tests/unit/termination/test_term_chaining.cc | 116 +++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/tests/unit/termination/test_term_chaining.cc b/tests/unit/termination/test_term_chaining.cc index 360d122440..90bcc08493 100644 --- a/tests/unit/termination/test_term_chaining.cc +++ b/tests/unit/termination/test_term_chaining.cc @@ -63,6 +63,14 @@ struct TestTermChaining : TestParallelHarness { static vt::messaging::DependentSendChain chain; static vt::EpochType epoch; + struct ChainReduceMsg : collective::ReduceMsg { + ChainReduceMsg(int in_num) + : num(in_num) + {} + + int num = 0; + }; + static void test_handler_reflector(TestMsg* msg) { fmt::print("reflector run\n"); @@ -98,6 +106,29 @@ struct TestTermChaining : TestParallelHarness { handler_count = 4; } + static void test_handler_set(TestMsg* msg) { + handler_count = 1; + } + + static void test_handler_reduce(ChainReduceMsg *msg) { + if (msg->isRoot()) { + EXPECT_EQ(theContext()->getNode(), 0); + EXPECT_EQ(handler_count, 1); + auto n = theContext()->getNumNodes(); + EXPECT_EQ(msg->num, n * (n - 1)/2); + handler_count = 2; + } + } + + static void test_handler_bcast(TestMsg* msg) { + if (theContext()->getNode() == 0) { + EXPECT_EQ(handler_count, 2); + } else { + EXPECT_EQ(handler_count, 12); + } + handler_count = 3; + } + static void start_chain() { EpochType epoch1 = theTerm()->makeEpochRooted(); vt::theMsg()->pushEpoch(epoch1); @@ -116,6 +147,58 @@ struct TestTermChaining : TestParallelHarness { chain.done(); } + static void chain_reduce() { + auto node = theContext()->getNode(); + + if (0 == node) { + EpochType epoch1 = theTerm()->makeEpochRooted(); + vt::theMsg()->pushEpoch(epoch1); + auto msg = makeSharedMessage(); + chain.add( + epoch1, theMsg()->sendMsg(1, msg)); + vt::theMsg()->popEpoch(epoch1); + vt::theTerm()->finishedEpoch(epoch1); + } + + EpochType epoch2 = theTerm()->makeEpochCollective(); + vt::theMsg()->pushEpoch(epoch2); + auto msg2 = makeSharedMessage(theContext()->getNode()); + chain.add(epoch2, theCollective()->reduce(0, msg2)); + vt::theMsg()->popEpoch(epoch2); + vt::theTerm()->finishedEpoch(epoch2); + + // Broadcast from both nodes, bcast wont send to itself + EpochType epoch3 = theTerm()->makeEpochRooted(); + vt::theMsg()->pushEpoch(epoch3); + auto msg3 = makeSharedMessage(); + chain.add( + epoch3, theMsg()->broadcastMsg(msg3)); + vt::theMsg()->popEpoch(epoch3); + vt::theTerm()->finishedEpoch(epoch3); + + chain.done(); + } + + static void chain_reduce_single() { + handler_count = 1; + + EpochType epoch2 = theTerm()->makeEpochRooted(); + vt::theMsg()->pushEpoch(epoch2); + auto msg2 = makeSharedMessage(theContext()->getNode()); + chain.add(epoch2, theCollective()->reduce(0, msg2)); + vt::theMsg()->popEpoch(epoch2); + vt::theTerm()->finishedEpoch(epoch2); + + EpochType epoch3 = theTerm()->makeEpochRooted(); + vt::theMsg()->pushEpoch(epoch3); + auto msg3 = makeSharedMessage(); + chain.add(epoch3, theMsg()->broadcastMsg(msg3)); + vt::theMsg()->popEpoch(epoch3); + vt::theTerm()->finishedEpoch(epoch3); + + chain.done(); + } + static void run_to_term() { bool finished = false; @@ -140,6 +223,8 @@ TEST_F(TestTermChaining, test_termination_chaining_1) { epoch = theTerm()->makeEpochCollective(); + handler_count = 0; + fmt::print("global collective epoch {:x}\n", epoch); if (this_node == 0) { @@ -161,4 +246,35 @@ TEST_F(TestTermChaining, test_termination_chaining_1) { } } +TEST_F(TestTermChaining, test_termination_chaining_collective_1) { + auto const& num_nodes = theContext()->getNumNodes(); + + if (num_nodes == 2) { + + chain = vt::messaging::DependentSendChain{}; + epoch = theTerm()->makeEpochCollective(); + + handler_count = 0; + + theMsg()->pushEpoch(epoch); + chain_reduce(); + theTerm()->finishedEpoch(epoch); + theMsg()->popEpoch(epoch); + run_to_term(); + EXPECT_EQ(handler_count, 3); + } else if (num_nodes == 1) { + chain = vt::messaging::DependentSendChain{}; + epoch = theTerm()->makeEpochCollective(); + + handler_count = 0; + + theMsg()->pushEpoch(epoch); + chain_reduce_single(); + theTerm()->finishedEpoch(epoch); + theMsg()->popEpoch(epoch); + run_to_term(); + // EXPECT_EQ(handler_count, 3); + } +} + }}} // end namespace vt::tests::unit