diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.h b/src/vt/collective/reduce/allreduce/rabenseifner.h index 2b6832d440..454d175e63 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.h @@ -41,7 +41,6 @@ //@HEADER */ - #if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H diff --git a/src/vt/collective/reduce/allreduce/rabenseifner_msg.h b/src/vt/collective/reduce/allreduce/rabenseifner_msg.h index a5e53aa692..66cd43bb39 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner_msg.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner_msg.h @@ -41,7 +41,6 @@ //@HEADER */ - #if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_MSG_H #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_MSG_H #include "vt/config.h" diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h index accd39b6d6..ea3174b993 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h @@ -41,7 +41,6 @@ //@HEADER */ - #if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RECURSIVE_DOUBLING_IMPL_H #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RECURSIVE_DOUBLING_IMPL_H diff --git a/src/vt/group/group_manager.impl.h b/src/vt/group/group_manager.impl.h index 9037e74368..96ede85a0e 100644 --- a/src/vt/group/group_manager.impl.h +++ b/src/vt/group/group_manager.impl.h @@ -41,7 +41,6 @@ //@HEADER */ -#include "vt/objgroup/proxy/proxy_objgroup.h" #if !defined INCLUDED_VT_GROUP_GROUP_MANAGER_IMPL_H #define INCLUDED_VT_GROUP_GROUP_MANAGER_IMPL_H diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h index 8aa6d8d504..ec55effd92 100644 --- a/src/vt/objgroup/manager.impl.h +++ b/src/vt/objgroup/manager.impl.h @@ -41,8 +41,6 @@ //@HEADER */ -#include "vt/group/group_manager.h" -#include "vt/utils/fntraits/fntraits.h" #if !defined INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H #define INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H diff --git a/src/vt/pipe/pipe_manager.h b/src/vt/pipe/pipe_manager.h index 1fd21f1b59..bdb46e4f76 100644 --- a/src/vt/pipe/pipe_manager.h +++ b/src/vt/pipe/pipe_manager.h @@ -135,6 +135,16 @@ struct PipeManager template auto makeBcast(ProxyT proxy); + /** + * \brief Make collective broadcast callback to collection + * + * \param[in] proxy the proxy to target + * + * \return a callback + */ + template + auto makeBcastCollective(ProxyT proxy); + /** * \brief Make callback to a function (including lambdas) with a context * pointer to any object on this node. diff --git a/src/vt/pipe/pipe_manager.impl.h b/src/vt/pipe/pipe_manager.impl.h index c365726df1..5cd51508a6 100644 --- a/src/vt/pipe/pipe_manager.impl.h +++ b/src/vt/pipe/pipe_manager.impl.h @@ -154,6 +154,11 @@ auto PipeManager::makeBcast(ProxyT proxy) { return makeCallbackProxy(proxy); } +template +auto PipeManager::makeBcastCollective(ProxyT proxy) { + return makeCallbackBcastCollectiveProxy(proxy); +} + template * f> Callback PipeManager::makeBcast() { return makeCallbackSingle(); diff --git a/src/vt/pipe/pipe_manager_tl.h b/src/vt/pipe/pipe_manager_tl.h index 8790ca753e..a9f3e27b3c 100644 --- a/src/vt/pipe/pipe_manager_tl.h +++ b/src/vt/pipe/pipe_manager_tl.h @@ -114,7 +114,7 @@ struct PipeManagerTL : virtual PipeManagerBase { auto makeCallbackProxy(ProxyT proxy); template - auto makeCallbackBcastProxy(ProxyT proxy); + auto makeCallbackBcastCollectiveProxy(ProxyT proxy); // Multi-staged callback template diff --git a/src/vt/pipe/pipe_manager_tl.impl.h b/src/vt/pipe/pipe_manager_tl.impl.h index d70f3623d1..887d1e16dd 100644 --- a/src/vt/pipe/pipe_manager_tl.impl.h +++ b/src/vt/pipe/pipe_manager_tl.impl.h @@ -142,7 +142,7 @@ template using hasIdx_t = typename U::IndexType; template -auto PipeManagerTL::makeCallbackBcastProxy(ProxyT proxy) { +auto PipeManagerTL::makeCallbackBcastCollectiveProxy(ProxyT proxy) { bool const persist = true; bool const send_back = false; bool const dispatch = true; diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index 2062584fda..6bf1fdd48e 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -41,8 +41,6 @@ //@HEADER */ -#include "vt/configs/types/types_type.h" -#include #if !defined INCLUDED_VT_VRT_COLLECTION_MANAGER_H #define INCLUDED_VT_VRT_COLLECTION_MANAGER_H diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index d007e4338d..a343f565dc 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -41,9 +41,6 @@ //@HEADER */ -#include "vt/collective/reduce/scoping/strong_types.h" -#include "vt/messaging/message/smart_ptr.h" -#include "vt/vrt/collection/manager.fwd.h" #if !defined INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H #define INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H @@ -936,7 +933,7 @@ messaging::PendingSend CollectionManager::reduceLocal( auto* obj = obj_proxy[theContext()->getNode()].get(); obj->proxy_ = obj_proxy; - auto cb = vt::theCB()->makeCallbackBcastProxy(proxy); + auto cb = vt::theCB()->makeCallbackBcastCollectiveProxy(proxy); obj->setFinalHandler(cb); if(num_elms == 1){ diff --git a/tests/unit/pipe/test_callback_bcast_collective.cc b/tests/unit/pipe/test_callback_bcast_collective.cc new file mode 100644 index 0000000000..adf2921541 --- /dev/null +++ b/tests/unit/pipe/test_callback_bcast_collective.cc @@ -0,0 +1,299 @@ +/* +//@HEADER +// ***************************************************************************** +// +// test_callback_bcast_collective.cc +// DARMA/vt => Virtual Transport +// +// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER +*/ + +#include + +#include "test_parallel_harness.h" +#include "data_message.h" +#include "test_helpers.h" + +#include "vt/context/context.h" +#include "vt/vrt/collection/manager.h" + +#include + +namespace vt { namespace tests { namespace unit { namespace bcast { + +using namespace vt; +using namespace vt::tests::unit; + +struct CallbackMsg : vt::Message { + CallbackMsg() = default; + explicit CallbackMsg(Callback<> in_cb) : cb_(in_cb) { } + + Callback<> cb_; +}; + +struct DataMsg : vt::Message { + DataMsg() = default; + DataMsg(int in_a, int in_b, int in_c) : a(in_a), b(in_b), c(in_c) { } + int a = 0, b = 0, c = 0; +}; + +struct CallbackDataMsg : vt::Message { + CallbackDataMsg() = default; + explicit CallbackDataMsg(Callback in_cb) : cb_(in_cb) { } + + Callback cb_; +}; + +struct TestCallbackBcastCollective : TestParallelHarness { + static void testHandler(CallbackDataMsg* msg) { + msg->cb_.send(8,9,10); + } + static void testHandlerEmpty(CallbackMsg* msg) { + msg->cb_.send(); + } +}; + +struct TestCol : vt::Collection { + TestCol() = default; + + virtual ~TestCol() = default; + + void check1() { + if (theContext()->getNode() == 0) { + EXPECT_EQ(val, 29); + } else { + EXPECT_EQ(val, 17); + } + } + + void check2() { + if (theContext()->getNode() == 1) { + EXPECT_EQ(val, 13); + } else { + EXPECT_EQ(val, 17); + } + } + + void check3() { + if (theContext()->getNode() == 0) { + EXPECT_EQ(val, 15); + } else { + EXPECT_EQ(val, 17); + } + } + + void check4() { + if (theContext()->getNode() == 0) { + EXPECT_EQ(val, 21); + } else { + EXPECT_EQ(val, 17); + } + } + + void check5() { + if (theContext()->getNode() == 1) { + EXPECT_EQ(val, 99); + } else { + EXPECT_EQ(val, 17); + } + } + + void cb1(DataMsg* msg) { + EXPECT_EQ(msg->a, 8); + EXPECT_EQ(msg->b, 9); + EXPECT_EQ(msg->c, 10); + val = 29; + } + + void cb2(DataMsg* msg) { + EXPECT_EQ(msg->a, 8); + EXPECT_EQ(msg->b, 9); + EXPECT_EQ(msg->c, 10); + val = 13; + } + + void cb3(int a, int b) { + EXPECT_EQ(a, 8); + EXPECT_EQ(b, 9); + val = 15; + } + + void cb4(std::string str, int a) { + EXPECT_EQ(a, 8); + EXPECT_EQ(str, "hello"); + val = 21; + } + +public: + int32_t val = 17; +}; + +static void cb5(TestCol* col, DataMsg* msg) { + EXPECT_EQ(msg->a, 8); + EXPECT_EQ(msg->b, 9); + EXPECT_EQ(msg->c, 10); + col->val = 99; +} + +TEST_F(TestCallbackBcastCollective, test_callback_bcast_collective_1) { + auto const this_node = theContext()->getNode(); + auto const range = Index1D(32); + + vt::CollectionProxy proxy; + + if (this_node == 0) { + proxy = theCollection()->construct( + range, "test_callback_bcast_collective_1" + ); + } + + runInEpochCollective([&]{ + if (this_node == 0) { + auto cb = theCB()->makeBcastCollective<&TestCol::cb1>(proxy); + cb.send(8,9,10); + } + }); + + runInEpochCollective([&]{ + if (this_node == 0) { + proxy.broadcast<&TestCol::check1>(); + } + }); +} + +TEST_F(TestCallbackBcastCollective, test_callback_bcast_collective_2) { + SET_MIN_NUM_NODES_CONSTRAINT(2); + + auto const this_node = theContext()->getNode(); + auto const num_nodes = theContext()->getNumNodes(); + + auto const range = Index1D(32); + + vt::CollectionProxy proxy; + + if (this_node == 0) { + proxy = theCollection()->construct( + range, "test_callback_bcast_collective_2" + ); + } + + runInEpochCollective([&]{ + if (this_node == 0) { + auto next = this_node + 1 < num_nodes ? this_node + 1 : 0; + auto cb = theCB()->makeBcastCollective<&TestCol::cb2>(proxy); + auto msg = makeMessage(cb); + theMsg()->sendMsg(next, msg); + } + }); + + runInEpochCollective([&]{ + if (this_node == 0) { + proxy.broadcast<&TestCol::check2>(); + } + }); +} + +TEST_F(TestCallbackBcastCollective, test_callback_bcast_collective_param_1) { + auto const this_node = theContext()->getNode(); + + auto proxy = makeCollection("test_callback_bcast_collective_param_1") + .bulkInsert(Index1D(32)) + .wait(); + + runInEpochCollective([&]{ + if (this_node == 0) { + auto cb = theCB()->makeBcastCollective<&TestCol::cb3>(proxy); + cb.send(8, 9); + } + }); + + runInEpochCollective([&]{ + proxy.broadcastCollective<&TestCol::check3>(); + }); +} + +TEST_F(TestCallbackBcastCollective, test_callback_bcast_collective_param_2) { + auto const this_node = theContext()->getNode(); + + auto proxy = makeCollection("test_callback_bcast_collective_param_2") + .bulkInsert(Index1D(32)) + .wait(); + + runInEpochCollective([&]{ + if (this_node == 0) { + auto cb = theCB()->makeBcastCollective<&TestCol::cb4>(proxy); + cb.send("hello", 8); + } + }); + + runInEpochCollective([&]{ + proxy.broadcastCollective<&TestCol::check4>(); + }); +} + +TEST_F(TestCallbackBcastCollective, test_callback_bcast_collective_3) { + SET_MIN_NUM_NODES_CONSTRAINT(2); + + auto const this_node = theContext()->getNode(); + auto const num_nodes = theContext()->getNumNodes(); + + auto const range = Index1D(32); + + vt::CollectionProxy proxy; + + if (this_node == 0) { + proxy = theCollection()->construct( + range, "test_callback_bcast_collection_3" + ); + } + + runInEpochCollective([&]{ + if (this_node == 0) { + auto next = this_node + 1 < num_nodes ? this_node + 1 : 0; + auto cb = theCB()->makeBcastCollective(proxy); + auto msg = makeMessage(cb); + theMsg()->sendMsg(next, msg); + } + }); + + runInEpochCollective([&]{ + if (this_node == 0) { + proxy.broadcast<&TestCol::check5>(); + } + }); +} + +}}}} // end namespace vt::tests::unit::bcast