Skip to content

Commit

Permalink
#702: rebase onto new ReduceStamp API
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Jul 15, 2020
1 parent c02e394 commit fe927c3
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 29 deletions.
69 changes: 66 additions & 3 deletions src/vt/collective/reduce/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ struct Reduce : virtual collective::tree::Tree {
*/
detail::ReduceStamp generateNextID();

/**
* \brief Reduce a message up the tree, possibly delayed through a pending send
*
* \param[in] root the root node where the final handler provides the result
* \param[in] msg the message to reduce on this node
* \param[in] id the reduction stamp (optional), provided if out-of-order
* \param[in] num_contrib number of expected contributions from this node
*
* \return the pending send corresponding to the reduce
*/
template <typename MsgT, ActiveTypedFnType<MsgT>* f>
PendingSendType reduce(
NodeType root, MsgT* const msg,
detail::ReduceStamp id = detail::ReduceStamp{},
ReduceNumType num_contrib = 1
);

/**
* \brief Reduce a message up the tree
*
Expand All @@ -119,7 +136,7 @@ struct Reduce : virtual collective::tree::Tree {
* \return the next reduction stamp
*/
template <typename MsgT, ActiveTypedFnType<MsgT>* f>
detail::ReduceStamp reduce(
detail::ReduceStamp reduceImmediate(
NodeType root, MsgT* const msg,
detail::ReduceStamp id = detail::ReduceStamp{},
ReduceNumType num_contrib = 1
Expand All @@ -143,7 +160,31 @@ struct Reduce : virtual collective::tree::Tree {
MsgT, OpT, collective::reduce::operators::ReduceCallback<MsgT>
>
>
detail::ReduceStamp reduce(
PendingSendType reduce(
NodeType const& root, MsgT* msg, Callback<MsgT> cb,
detail::ReduceStamp id = detail::ReduceStamp{},
ReduceNumType const& num_contrib = 1
);

/**
* \brief Reduce a message up the tree
*
* \param[in] root the root node where the final handler provides the result
* \param[in] msg the message to reduce on this node
* \param[in] cb the callback to trigger on the root node
* \param[in] id the reduction stamp (optional), provided if out-of-order
* \param[in] num_contrib number of expected contributions from this node
*
* \return the next reduction stamp
*/
template <
typename OpT,
typename MsgT,
ActiveTypedFnType<MsgT> *f = MsgT::template msgHandler<
MsgT, OpT, collective::reduce::operators::ReduceCallback<MsgT>
>
>
detail::ReduceStamp reduceImmediate(
NodeType const& root, MsgT* msg, Callback<MsgT> cb,
detail::ReduceStamp id = detail::ReduceStamp{},
ReduceNumType const& num_contrib = 1
Expand All @@ -165,7 +206,29 @@ struct Reduce : virtual collective::tree::Tree {
typename MsgT,
ActiveTypedFnType<MsgT> *f = MsgT::template msgHandler<MsgT, OpT, FunctorT>
>
detail::ReduceStamp reduce(
PendingSendType reduce(
NodeType const& root, MsgT* msg,
detail::ReduceStamp id = detail::ReduceStamp{},
ReduceNumType const& num_contrib = 1
);

/**
* \brief Reduce a message up the tree with a target function on the root node
*
* \param[in] root the root node where the final handler provides the result
* \param[in] msg the message to reduce on this node
* \param[in] id the reduction stamp (optional), provided if out-of-order
* \param[in] num_contrib number of expected contributions from this node
*
* \return the next reduction stamp
*/
template <
typename OpT,
typename FunctorT,
typename MsgT,
ActiveTypedFnType<MsgT> *f = MsgT::template msgHandler<MsgT, OpT, FunctorT>
>
detail::ReduceStamp reduceImmediate(
NodeType const& root, MsgT* msg,
detail::ReduceStamp id = detail::ReduceStamp{},
ReduceNumType const& num_contrib = 1
Expand Down
36 changes: 33 additions & 3 deletions src/vt/collective/reduce/reduce.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,56 @@ void Reduce::reduceRootRecv(MsgT* msg) {
}

template <typename OpT, typename MsgT, ActiveTypedFnType<MsgT> *f>
detail::ReduceStamp Reduce::reduce(
Reduce::PendingSendType Reduce::reduce(
NodeType const& root, MsgT* msg, Callback<MsgT> cb, detail::ReduceStamp id,
ReduceNumType const& num_contrib
) {
msg->setCallback(cb);
return reduce<MsgT,f>(root,msg,id,num_contrib);
}

template <typename OpT, typename MsgT, ActiveTypedFnType<MsgT> *f>
detail::ReduceStamp Reduce::reduceImmediate(
NodeType const& root, MsgT* msg, Callback<MsgT> cb, detail::ReduceStamp id,
ReduceNumType const& num_contrib
) {
msg->setCallback(cb);
return reduceImmediate<MsgT,f>(root,msg,id,num_contrib);
}

template <
typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType<MsgT> *f
>
detail::ReduceStamp Reduce::reduce(
Reduce::PendingSendType Reduce::reduce(
NodeType const& root, MsgT* msg, detail::ReduceStamp id,
ReduceNumType const& num_contrib
) {
return reduce<MsgT,f>(root,msg,id,num_contrib);
}

template <
typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType<MsgT> *f
>
detail::ReduceStamp Reduce::reduceImmediate(
NodeType const& root, MsgT* msg, detail::ReduceStamp id,
ReduceNumType const& num_contrib
) {
return reduceImmediate<MsgT,f>(root,msg,id,num_contrib);
}

template <typename MsgT, ActiveTypedFnType<MsgT>* f>
Reduce::PendingSendType Reduce::reduce(
NodeType root, MsgT* const msg, detail::ReduceStamp id,
ReduceNumType num_contrib
) {
auto msg_ptr = promoteMsg(msg);
return PendingSendType{theMsg()->getEpochContextMsg(msg_ptr), [=](){
reduceImmediate<MsgT, f>(root, msg_ptr.get(), id, num_contrib);
} };
}

template <typename MsgT, ActiveTypedFnType<MsgT>* f>
detail::ReduceStamp Reduce::reduce(
detail::ReduceStamp Reduce::reduceImmediate(
NodeType root, MsgT* const msg, detail::ReduceStamp id,
ReduceNumType num_contrib
) {
Expand Down
6 changes: 0 additions & 6 deletions src/vt/messaging/pending_send.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@

namespace vt { namespace messaging {

PendingSend::PendingSend(
MsgSharedPtr<BaseMsgType> const& in_msg, ByteType const& in_msg_size)
: msg_(in_msg.toVirtual<BaseMsgType>()), msg_size_(in_msg_size) {
produceMsg();
}

PendingSend::PendingSend(EpochType ep, EpochActionType const& in_action)
: epoch_action_{in_action}, epoch_produced_(ep) {
if (epoch_produced_ != no_epoch) {
Expand Down
2 changes: 1 addition & 1 deletion src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ ObjGroupManager::PendingSendType ObjGroupManager::reduce(
auto const objgroup = proxy.getProxy();

auto r = theCollective()->getReducerObjGroup(objgroup);
r->template reduce<MsgT,f>(root, msg.get(), stamp);
return r->template reduce<MsgT,f>(root, msg.get(), stamp);
}

template <typename ObjT>
Expand Down
12 changes: 6 additions & 6 deletions src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@ messaging::PendingSend CollectionManager::broadcastMsgUntypedHandler(
}

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
void CollectionManager::reduceMsgExpr(
messaging::PendingSend CollectionManager::reduceMsgExpr(
CollectionProxyWrapType<ColT> const& proxy,
MsgT *const raw_msg, ReduceIdxFuncType<typename ColT::IndexType> expr_fn,
ReduceStamp stamp, NodeType root
Expand All @@ -1069,7 +1069,7 @@ void CollectionManager::reduceMsgExpr(
auto const col_proxy = proxy.getProxy();
auto const cur_epoch = theMsg()->getEpochContextMsg(msg);

bufferOpOrExecute<ColT>(
return bufferOpOrExecute<ColT>(
col_proxy,
BufferTypeEnum::Reduce,
static_cast<BufferReleaseEnum>(
Expand Down Expand Up @@ -1113,7 +1113,7 @@ void CollectionManager::reduceMsgExpr(
r = theCollective()->getReducerVrtProxy(col_proxy);
}

auto ret_stamp = r->reduce<MsgT,f>(root_node, msg.get(), cur_stamp, num_elms);
auto ret_stamp = r->reduceImmediate<MsgT,f>(root_node, msg.get(), cur_stamp, num_elms);

debug_print(
vrt_coll, node,
Expand Down Expand Up @@ -1141,23 +1141,23 @@ void CollectionManager::reduceMsgExpr(
}

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
void CollectionManager::reduceMsg(
messaging::PendingSend CollectionManager::reduceMsg(
CollectionProxyWrapType<ColT> const& proxy,
MsgT *const msg, ReduceStamp stamp, NodeType root
) {
return reduceMsgExpr<ColT,MsgT,f>(proxy,msg,nullptr,stamp,root);
}

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
void CollectionManager::reduceMsg(
messaging::PendingSend CollectionManager::reduceMsg(
CollectionProxyWrapType<ColT> const& proxy,
MsgT *const msg, ReduceStamp stamp, typename ColT::IndexType const& idx
) {
return reduceMsgExpr<ColT,MsgT,f>(proxy,msg,nullptr,stamp,idx);
}

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
void CollectionManager::reduceMsgExpr(
messaging::PendingSend CollectionManager::reduceMsgExpr(
CollectionProxyWrapType<ColT> const& proxy,
MsgT *const msg, ReduceIdxFuncType<typename ColT::IndexType> expr_fn,
ReduceStamp stamp, typename ColT::IndexType const& idx
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/termination/test_term_chaining.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,26 +157,26 @@ struct TestTermChaining : TestParallelHarness {
if (0 == node) {
EpochType epoch1 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch1);
auto msg = makeSharedMessage<TestMsg>();
auto msg = makeMessage<TestMsg>();
chain.add(
epoch1, theMsg()->sendMsg<TestMsg, test_handler_reflector>(1, msg));
epoch1, theMsg()->sendMsg<TestMsg, test_handler_reflector>(1, msg.get()));
vt::theMsg()->popEpoch(epoch1);
vt::theTerm()->finishedEpoch(epoch1);
}

EpochType epoch2 = theTerm()->makeEpochCollective();
vt::theMsg()->pushEpoch(epoch2);
auto msg2 = makeSharedMessage<ChainReduceMsg>(theContext()->getNode());
chain.add(epoch2, theCollective()->reduce<ChainReduceMsg, test_handler_reduce>(0, msg2));
auto msg2 = makeMessage<ChainReduceMsg>(theContext()->getNode());
chain.add(epoch2, theCollective()->global()->reduce<ChainReduceMsg, test_handler_reduce>(0, msg2.get()));
vt::theMsg()->popEpoch(epoch2);
vt::theTerm()->finishedEpoch(epoch2);

// Broadcast from both nodes, bcast wont send to itself
EpochType epoch3 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch3);
auto msg3 = makeSharedMessage<TestMsg>();
auto msg3 = makeMessage<TestMsg>();
chain.add(
epoch3, theMsg()->broadcastMsg<TestMsg, test_handler_bcast>(msg3));
epoch3, theMsg()->broadcastMsg<TestMsg, test_handler_bcast>(msg3.get()));
vt::theMsg()->popEpoch(epoch3);
vt::theTerm()->finishedEpoch(epoch3);

Expand All @@ -188,15 +188,15 @@ struct TestTermChaining : TestParallelHarness {

EpochType epoch2 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch2);
auto msg2 = makeSharedMessage<ChainReduceMsg>(theContext()->getNode());
chain.add(epoch2, theCollective()->reduce<ChainReduceMsg, test_handler_reduce>(0, msg2));
auto msg2 = makeMessage<ChainReduceMsg>(theContext()->getNode());
chain.add(epoch2, theCollective()->global()->reduce<ChainReduceMsg, test_handler_reduce>(0, msg2.get()));
vt::theMsg()->popEpoch(epoch2);
vt::theTerm()->finishedEpoch(epoch2);

EpochType epoch3 = theTerm()->makeEpochRooted();
vt::theMsg()->pushEpoch(epoch3);
auto msg3 = makeSharedMessage<TestMsg>();
chain.add(epoch3, theMsg()->broadcastMsg<TestMsg, test_handler_bcast>(msg3));
auto msg3 = makeMessage<TestMsg>();
chain.add(epoch3, theMsg()->broadcastMsg<TestMsg, test_handler_bcast>(msg3.get()));
vt::theMsg()->popEpoch(epoch3);
vt::theTerm()->finishedEpoch(epoch3);

Expand Down

0 comments on commit fe927c3

Please sign in to comment.