Skip to content

Commit

Permalink
[#1024]: Add bcast collective with non-member message handler
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 19, 2020
1 parent b46c1cf commit 920ce6b
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 20 deletions.
9 changes: 9 additions & 0 deletions src/vt/vrt/collection/broadcast/broadcastable.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ struct Broadcastable : BaseProxyT {
>
messaging::PendingSend broadcast(Args&&... args) const;

template <typename MsgT, ActiveColTypedFnType<MsgT, ColT> *f>
messaging::PendingSend broadcastCollective(MsgT* msg) const;
template <typename MsgT, ActiveColTypedFnType<MsgT, ColT> *f>
messaging::PendingSend broadcastCollective(MsgSharedPtr<MsgT> msg) const;
template <
typename MsgT, ActiveColTypedFnType<MsgT, ColT> *f, typename... Args
>
messaging::PendingSend broadcastCollective(Args&&... args) const;

template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
messaging::PendingSend broadcastCollective(MsgT* msg) const;
template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
Expand Down
24 changes: 23 additions & 1 deletion src/vt/vrt/collection/broadcast/broadcastable.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,36 @@ messaging::PendingSend Broadcastable<ColT,IndexT,BaseProxyT>::broadcast(Args&&..
return broadcast<MsgT,f>(makeMessage<MsgT>(std::forward<Args>(args)...));
}


template <typename ColT, typename IndexT, typename BaseProxyT>
template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
messaging::PendingSend Broadcastable<ColT,IndexT,BaseProxyT>::broadcast(MsgT* msg) const {
auto proxy = this->getProxy();
return theCollection()->broadcastMsg<MsgT, f>(proxy,msg);
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <typename MsgT, ActiveColTypedFnType<MsgT, ColT> *f>
messaging::PendingSend
Broadcastable<ColT, IndexT, BaseProxyT>::broadcastCollective(MsgSharedPtr<MsgT> msg) const {
return broadcastCollective<MsgT, f>(msg.get());
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <
typename MsgT, ActiveColTypedFnType<MsgT, ColT> *f, typename... Args>
messaging::PendingSend
Broadcastable<ColT, IndexT, BaseProxyT>::broadcastCollective(Args&&... args) const {
return broadcastCollective<MsgT, f>(makeMessage<MsgT>(std::forward<Args>(args)...));
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <typename MsgT, ActiveColTypedFnType<MsgT, ColT> *f>
messaging::PendingSend
Broadcastable<ColT, IndexT, BaseProxyT>::broadcastCollective(MsgT* msg) const {
auto proxy = this->getProxy();
return theCollection()->broadcastMsgCollective<MsgT, f>(proxy, msg);
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
messaging::PendingSend
Expand Down
11 changes: 11 additions & 0 deletions src/vt/vrt/collection/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -904,13 +904,24 @@ struct CollectionManager
bool instrument
);

template <
typename MsgT,
ActiveColTypedFnType<MsgT,typename MsgT::CollectionType> *f>
messaging::PendingSend broadcastMsgCollective(
CollectionProxyWrapType<typename MsgT::CollectionType> const& proxy,
MsgT* msg, bool instrument = true);

template <
typename MsgT,
ActiveColMemberTypedFnType<MsgT, typename MsgT::CollectionType> f>
messaging::PendingSend broadcastMsgCollective(
CollectionProxyWrapType<typename MsgT::CollectionType> const& proxy,
MsgT* msg, bool instrument = true);

template <typename MsgT, typename ColT>
messaging::PendingSend broadcastMsgCollectiveImpl(
CollectionProxyWrapType<ColT> const& proxy, MsgT* msg);

/**
* \brief Broadcast a message with action function handler
*
Expand Down
44 changes: 34 additions & 10 deletions src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,17 +848,43 @@ messaging::PendingSend CollectionManager::broadcastFromRoot(MsgT* raw_msg) {
return ret;
}

template <
typename MsgT, ActiveColTypedFnType<MsgT, typename MsgT::CollectionType>* f
>
messaging::PendingSend CollectionManager::broadcastMsgCollective(
CollectionProxyWrapType<typename MsgT::CollectionType> const& proxy,
MsgT* msg, bool instrument
) {
using ColT = typename MsgT::CollectionType;

msg->setVrtHandler(auto_registry::makeAutoHandlerCollection<ColT, MsgT, f>());
msg->setMember(false);

return broadcastMsgCollectiveImpl<MsgT, ColT>(proxy, msg);
}

template <
typename MsgT,
ActiveColMemberTypedFnType<MsgT, typename MsgT::CollectionType> f>
ActiveColMemberTypedFnType<MsgT, typename MsgT::CollectionType> f
>
messaging::PendingSend CollectionManager::broadcastMsgCollective(
CollectionProxyWrapType<typename MsgT::CollectionType> const& proxy,
MsgT* msg, bool instrument) {

MsgT* msg, bool instrument
) {
using ColT = typename MsgT::CollectionType;
using IndexT = typename ColT::IndexType;

msg->setVrtHandler(
auto_registry::makeAutoHandlerCollectionMem<ColT, MsgT, f>());
msg->setMember(true);

return broadcastMsgCollectiveImpl<MsgT, ColT>(proxy, msg);
}

template <typename MsgT, typename ColT>
messaging::PendingSend CollectionManager::broadcastMsgCollectiveImpl(
CollectionProxyWrapType<ColT> const& proxy, MsgT* msg
) {
using IndexT = typename ColT::IndexType;
auto promoMsg = promoteMsg(msg);

return messaging::PendingSend(
Expand All @@ -867,23 +893,21 @@ messaging::PendingSend CollectionManager::broadcastMsgCollective(
auto const node = theContext()->getNode();

auto col_msg = reinterpret_cast<MsgT*>(msgIn.get());
auto handler =
auto_registry::makeAutoHandlerCollectionMem<ColT, MsgT, f>();
col_msg->setVrtHandler(handler);

theMsg()->markAsCollectionMessage(col_msg);

if (elm_holder) {
elm_holder->foreach (
[node, msgIn, col_msg,
elm_holder](IndexT const& idx, CollectionBase<ColT, IndexT>* base) {
[node,
col_msg](IndexT const& idx, CollectionBase<ColT, IndexT>* base) {
auto const from = col_msg->getFromNode();
auto trace_event = trace::no_trace_event;
auto const hand = col_msg->getVrtHandler();
auto const member = col_msg->getMember();

collectionAutoMsgDeliver<
ColT, IndexT, MsgT, typename MsgT::UserMsgType>(
col_msg, base, hand, true, from, trace_event);
col_msg, base, hand, member, from, trace_event);
});
}
});
Expand Down
59 changes: 50 additions & 9 deletions tests/unit/collection/test_collection_group.extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

namespace vt { namespace tests { namespace unit {

static int32_t elemCounter = 0;
static int32_t elem_counter = 0;

struct MyReduceMsg : collective::ReduceTMsg<int> {
explicit MyReduceMsg(int const in_num)
Expand All @@ -81,22 +81,30 @@ struct ColA : Collection<ColA,Index1D> {
auto cb = theCB()->makeBcast<ColA, MyReduceMsg, &ColA::finishedReduce>(proxy);
auto reduce_msg = makeMessage<MyReduceMsg>(getIndex().x());
proxy.reduce<collective::PlusOp<int>>(reduce_msg.get(),cb);
reduce_test = true;
}

void memberHanlder(TestDataMsg* msg) {
EXPECT_EQ(msg->value_, theContext()->getNode());
--elemCounter;
finished = true;
--elem_counter;
}

virtual ~ColA() {
EXPECT_TRUE(finished);
if (reduce_test) {
EXPECT_TRUE(finished);
}
}

private:
private:
bool finished = false;
bool reduce_test = false;
};

void colHanlder(
ColA::TestDataMsg* msg, typename ColA::TestDataMsg::CollectionType* type) {
--elem_counter;
}

struct TestCollectionGroup : TestParallelHarness { };


Expand All @@ -114,16 +122,49 @@ TEST_F(TestCollectionGroup, test_collection_group_2){
auto const my_node = theContext()->getNode();

auto const range = Index1D(8);
auto const proxy = theCollection()->constructCollective<ColA>(
range, [](vt::Index1D idx) {
++elemCounter;
auto const proxy =
theCollection()->constructCollective<ColA>(range, [](vt::Index1D idx) {
++elem_counter;
return std::make_unique<ColA>();
});

const auto numElems = elem_counter;
auto msg = ::vt::makeMessage<ColA::TestDataMsg>(my_node);

proxy.broadcastCollective<ColA::TestDataMsg, &ColA::memberHanlder>(msg.get());
EXPECT_EQ(elem_counter, 0);

proxy.broadcastCollective<ColA::TestDataMsg, &ColA::memberHanlder>(msg);
EXPECT_EQ(elem_counter, -numElems);

proxy.broadcastCollective<
ColA::TestDataMsg, &ColA::memberHanlder, ColA::TestDataMsg>(my_node);
EXPECT_EQ(elem_counter, -2 * numElems);
}

TEST_F(TestCollectionGroup, test_collection_group_3){
elem_counter = 0;
auto const my_node = theContext()->getNode();

auto const range = Index1D(8);
auto const proxy =
theCollection()->constructCollective<ColA>(range, [](vt::Index1D idx) {
++elem_counter;
return std::make_unique<ColA>();
});

const auto numElems = elem_counter;
auto msg = ::vt::makeMessage<ColA::TestDataMsg>(my_node);

proxy.broadcastCollective<ColA::TestDataMsg, colHanlder>(msg.get());
EXPECT_EQ(elem_counter, 0);

proxy.broadcastCollective<ColA::TestDataMsg, &ColA::memberHanlder>(msg);
EXPECT_EQ(elem_counter, -numElems);

EXPECT_EQ(elemCounter, 0);
proxy.broadcastCollective<
ColA::TestDataMsg, &ColA::memberHanlder, ColA::TestDataMsg>(my_node);
EXPECT_EQ(elem_counter, -2 * numElems);
}

}}} // end namespace vt::tests::unit

0 comments on commit 920ce6b

Please sign in to comment.