diff --git a/src/vt/vrt/collection/broadcast/broadcastable.h b/src/vt/vrt/collection/broadcast/broadcastable.h index 60d3002108..bc34ec87a8 100644 --- a/src/vt/vrt/collection/broadcast/broadcastable.h +++ b/src/vt/vrt/collection/broadcast/broadcastable.h @@ -80,6 +80,15 @@ struct Broadcastable : BaseProxyT { > messaging::PendingSend broadcast(Args&&... args) const; + template *f> + messaging::PendingSend broadcastCollective(MsgT* msg) const; + template *f> + messaging::PendingSend broadcastCollective(MsgSharedPtr msg) const; + template < + typename MsgT, ActiveColTypedFnType *f, typename... Args + > + messaging::PendingSend broadcastCollective(Args&&... args) const; + template f> messaging::PendingSend broadcastCollective(MsgT* msg) const; template f> diff --git a/src/vt/vrt/collection/broadcast/broadcastable.impl.h b/src/vt/vrt/collection/broadcast/broadcastable.impl.h index db03e40144..07b54075b4 100644 --- a/src/vt/vrt/collection/broadcast/broadcastable.impl.h +++ b/src/vt/vrt/collection/broadcast/broadcastable.impl.h @@ -97,7 +97,6 @@ messaging::PendingSend Broadcastable::broadcast(Args&&.. return broadcast(makeMessage(std::forward(args)...)); } - template template f> messaging::PendingSend Broadcastable::broadcast(MsgT* msg) const { @@ -105,6 +104,29 @@ messaging::PendingSend Broadcastable::broadcast(MsgT* ms return theCollection()->broadcastMsg(proxy,msg); } +template +template *f> +messaging::PendingSend +Broadcastable::broadcastCollective(MsgSharedPtr msg) const { + return broadcastCollective(msg.get()); +} + +template +template < + typename MsgT, ActiveColTypedFnType *f, typename... Args> +messaging::PendingSend +Broadcastable::broadcastCollective(Args&&... args) const { + return broadcastCollective(makeMessage(std::forward(args)...)); +} + +template +template *f> +messaging::PendingSend +Broadcastable::broadcastCollective(MsgT* msg) const { + auto proxy = this->getProxy(); + return theCollection()->broadcastMsgCollective(proxy, msg); +} + template template f> messaging::PendingSend diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index c899f66d48..f1ad7b2479 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -904,6 +904,13 @@ struct CollectionManager bool instrument ); + template < + typename MsgT, + ActiveColTypedFnType *f> + messaging::PendingSend broadcastMsgCollective( + CollectionProxyWrapType const& proxy, + MsgT* msg, bool instrument = true); + template < typename MsgT, ActiveColMemberTypedFnType f> @@ -911,6 +918,10 @@ struct CollectionManager CollectionProxyWrapType const& proxy, MsgT* msg, bool instrument = true); + template + messaging::PendingSend broadcastMsgCollectiveImpl( + CollectionProxyWrapType const& proxy, MsgT* msg); + /** * \brief Broadcast a message with action function handler * diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index 10c2ed0052..c505b6a8ec 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -848,17 +848,43 @@ messaging::PendingSend CollectionManager::broadcastFromRoot(MsgT* raw_msg) { return ret; } +template < + typename MsgT, ActiveColTypedFnType* f +> +messaging::PendingSend CollectionManager::broadcastMsgCollective( + CollectionProxyWrapType const& proxy, + MsgT* msg, bool instrument +) { + using ColT = typename MsgT::CollectionType; + + msg->setVrtHandler(auto_registry::makeAutoHandlerCollection()); + msg->setMember(false); + + return broadcastMsgCollectiveImpl(proxy, msg); +} template < typename MsgT, - ActiveColMemberTypedFnType f> + ActiveColMemberTypedFnType f +> messaging::PendingSend CollectionManager::broadcastMsgCollective( CollectionProxyWrapType const& proxy, - MsgT* msg, bool instrument) { - + MsgT* msg, bool instrument +) { using ColT = typename MsgT::CollectionType; - using IndexT = typename ColT::IndexType; + msg->setVrtHandler( + auto_registry::makeAutoHandlerCollectionMem()); + msg->setMember(true); + + return broadcastMsgCollectiveImpl(proxy, msg); +} + +template +messaging::PendingSend CollectionManager::broadcastMsgCollectiveImpl( + CollectionProxyWrapType const& proxy, MsgT* msg +) { + using IndexT = typename ColT::IndexType; auto promoMsg = promoteMsg(msg); return messaging::PendingSend( @@ -867,23 +893,21 @@ messaging::PendingSend CollectionManager::broadcastMsgCollective( auto const node = theContext()->getNode(); auto col_msg = reinterpret_cast(msgIn.get()); - auto handler = - auto_registry::makeAutoHandlerCollectionMem(); - col_msg->setVrtHandler(handler); theMsg()->markAsCollectionMessage(col_msg); if (elm_holder) { elm_holder->foreach ( - [node, msgIn, col_msg, - elm_holder](IndexT const& idx, CollectionBase* base) { + [node, + col_msg](IndexT const& idx, CollectionBase* base) { auto const from = col_msg->getFromNode(); auto trace_event = trace::no_trace_event; auto const hand = col_msg->getVrtHandler(); + auto const member = col_msg->getMember(); collectionAutoMsgDeliver< ColT, IndexT, MsgT, typename MsgT::UserMsgType>( - col_msg, base, hand, true, from, trace_event); + col_msg, base, hand, member, from, trace_event); }); } }); diff --git a/tests/unit/collection/test_collection_group.extended.cc b/tests/unit/collection/test_collection_group.extended.cc index 2ab78a5df6..8a7f315b15 100644 --- a/tests/unit/collection/test_collection_group.extended.cc +++ b/tests/unit/collection/test_collection_group.extended.cc @@ -54,7 +54,7 @@ namespace vt { namespace tests { namespace unit { -static int32_t elemCounter = 0; +static int32_t elem_counter = 0; struct MyReduceMsg : collective::ReduceTMsg { explicit MyReduceMsg(int const in_num) @@ -81,22 +81,30 @@ struct ColA : Collection { auto cb = theCB()->makeBcast(proxy); auto reduce_msg = makeMessage(getIndex().x()); proxy.reduce>(reduce_msg.get(),cb); + reduce_test = true; } void memberHanlder(TestDataMsg* msg) { EXPECT_EQ(msg->value_, theContext()->getNode()); - --elemCounter; - finished = true; + --elem_counter; } virtual ~ColA() { - EXPECT_TRUE(finished); + if (reduce_test) { + EXPECT_TRUE(finished); + } } -private: + private: bool finished = false; + bool reduce_test = false; }; +void colHanlder( + ColA::TestDataMsg* msg, typename ColA::TestDataMsg::CollectionType* type) { + --elem_counter; +} + struct TestCollectionGroup : TestParallelHarness { }; @@ -114,16 +122,49 @@ TEST_F(TestCollectionGroup, test_collection_group_2){ auto const my_node = theContext()->getNode(); auto const range = Index1D(8); - auto const proxy = theCollection()->constructCollective( - range, [](vt::Index1D idx) { - ++elemCounter; + auto const proxy = + theCollection()->constructCollective(range, [](vt::Index1D idx) { + ++elem_counter; return std::make_unique(); }); + const auto numElems = elem_counter; auto msg = ::vt::makeMessage(my_node); + proxy.broadcastCollective(msg.get()); + EXPECT_EQ(elem_counter, 0); + + proxy.broadcastCollective(msg); + EXPECT_EQ(elem_counter, -numElems); + + proxy.broadcastCollective< + ColA::TestDataMsg, &ColA::memberHanlder, ColA::TestDataMsg>(my_node); + EXPECT_EQ(elem_counter, -2 * numElems); +} + +TEST_F(TestCollectionGroup, test_collection_group_3){ + elem_counter = 0; + auto const my_node = theContext()->getNode(); + + auto const range = Index1D(8); + auto const proxy = + theCollection()->constructCollective(range, [](vt::Index1D idx) { + ++elem_counter; + return std::make_unique(); + }); + + const auto numElems = elem_counter; + auto msg = ::vt::makeMessage(my_node); + + proxy.broadcastCollective(msg.get()); + EXPECT_EQ(elem_counter, 0); + + proxy.broadcastCollective(msg); + EXPECT_EQ(elem_counter, -numElems); - EXPECT_EQ(elemCounter, 0); + proxy.broadcastCollective< + ColA::TestDataMsg, &ColA::memberHanlder, ColA::TestDataMsg>(my_node); + EXPECT_EQ(elem_counter, -2 * numElems); } }}} // end namespace vt::tests::unit