diff --git a/src/vt/collective/reduce/reduce.h b/src/vt/collective/reduce/reduce.h index e38f18888b..3e3f46e3b2 100644 --- a/src/vt/collective/reduce/reduce.h +++ b/src/vt/collective/reduce/reduce.h @@ -108,6 +108,23 @@ struct Reduce : virtual collective::tree::Tree { */ detail::ReduceStamp generateNextID(); + /** + * \brief Reduce a message up the tree, possibly delayed through a pending send + * + * \param[in] root the root node where the final handler provides the result + * \param[in] msg the message to reduce on this node + * \param[in] id the reduction stamp (optional), provided if out-of-order + * \param[in] num_contrib number of expected contributions from this node + * + * \return the pending send corresponding to the reduce + */ + template * f> + PendingSendType reduce( + NodeType root, MsgT* const msg, + detail::ReduceStamp id = detail::ReduceStamp{}, + ReduceNumType num_contrib = 1 + ); + /** * \brief Reduce a message up the tree * @@ -119,7 +136,7 @@ struct Reduce : virtual collective::tree::Tree { * \return the next reduction stamp */ template * f> - detail::ReduceStamp reduce( + detail::ReduceStamp reduceImmediate( NodeType root, MsgT* const msg, detail::ReduceStamp id = detail::ReduceStamp{}, ReduceNumType num_contrib = 1 @@ -143,7 +160,31 @@ struct Reduce : virtual collective::tree::Tree { MsgT, OpT, collective::reduce::operators::ReduceCallback > > - detail::ReduceStamp reduce( + PendingSendType reduce( + NodeType const& root, MsgT* msg, Callback cb, + detail::ReduceStamp id = detail::ReduceStamp{}, + ReduceNumType const& num_contrib = 1 + ); + + /** + * \brief Reduce a message up the tree + * + * \param[in] root the root node where the final handler provides the result + * \param[in] msg the message to reduce on this node + * \param[in] cb the callback to trigger on the root node + * \param[in] id the reduction stamp (optional), provided if out-of-order + * \param[in] num_contrib number of expected contributions from this node + * + * \return the next reduction stamp + */ + template < + typename OpT, + typename MsgT, + ActiveTypedFnType *f = MsgT::template msgHandler< + MsgT, OpT, collective::reduce::operators::ReduceCallback + > + > + detail::ReduceStamp reduceImmediate( NodeType const& root, MsgT* msg, Callback cb, detail::ReduceStamp id = detail::ReduceStamp{}, ReduceNumType const& num_contrib = 1 @@ -165,7 +206,29 @@ struct Reduce : virtual collective::tree::Tree { typename MsgT, ActiveTypedFnType *f = MsgT::template msgHandler > - detail::ReduceStamp reduce( + PendingSendType reduce( + NodeType const& root, MsgT* msg, + detail::ReduceStamp id = detail::ReduceStamp{}, + ReduceNumType const& num_contrib = 1 + ); + + /** + * \brief Reduce a message up the tree with a target function on the root node + * + * \param[in] root the root node where the final handler provides the result + * \param[in] msg the message to reduce on this node + * \param[in] id the reduction stamp (optional), provided if out-of-order + * \param[in] num_contrib number of expected contributions from this node + * + * \return the next reduction stamp + */ + template < + typename OpT, + typename FunctorT, + typename MsgT, + ActiveTypedFnType *f = MsgT::template msgHandler + > + detail::ReduceStamp reduceImmediate( NodeType const& root, MsgT* msg, detail::ReduceStamp id = detail::ReduceStamp{}, ReduceNumType const& num_contrib = 1 diff --git a/src/vt/collective/reduce/reduce.impl.h b/src/vt/collective/reduce/reduce.impl.h index cdf16080ac..1f4728df17 100644 --- a/src/vt/collective/reduce/reduce.impl.h +++ b/src/vt/collective/reduce/reduce.impl.h @@ -80,7 +80,7 @@ void Reduce::reduceRootRecv(MsgT* msg) { } template *f> -detail::ReduceStamp Reduce::reduce( +Reduce::PendingSendType Reduce::reduce( NodeType const& root, MsgT* msg, Callback cb, detail::ReduceStamp id, ReduceNumType const& num_contrib ) { @@ -88,18 +88,48 @@ detail::ReduceStamp Reduce::reduce( return reduce(root,msg,id,num_contrib); } +template *f> +detail::ReduceStamp Reduce::reduceImmediate( + NodeType const& root, MsgT* msg, Callback cb, detail::ReduceStamp id, + ReduceNumType const& num_contrib +) { + msg->setCallback(cb); + return reduceImmediate(root,msg,id,num_contrib); +} + template < typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType *f > -detail::ReduceStamp Reduce::reduce( +Reduce::PendingSendType Reduce::reduce( NodeType const& root, MsgT* msg, detail::ReduceStamp id, ReduceNumType const& num_contrib ) { return reduce(root,msg,id,num_contrib); } +template < + typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType *f +> +detail::ReduceStamp Reduce::reduceImmediate( + NodeType const& root, MsgT* msg, detail::ReduceStamp id, + ReduceNumType const& num_contrib +) { + return reduceImmediate(root,msg,id,num_contrib); +} + +template * f> +Reduce::PendingSendType Reduce::reduce( + NodeType root, MsgT* const msg, detail::ReduceStamp id, + ReduceNumType num_contrib +) { + auto msg_ptr = promoteMsg(msg); + return PendingSendType{theMsg()->getEpochContextMsg(msg_ptr), [=](){ + reduceImmediate(root, msg_ptr.get(), id, num_contrib); + } }; +} + template * f> -detail::ReduceStamp Reduce::reduce( +detail::ReduceStamp Reduce::reduceImmediate( NodeType root, MsgT* const msg, detail::ReduceStamp id, ReduceNumType num_contrib ) { diff --git a/src/vt/messaging/pending_send.cc b/src/vt/messaging/pending_send.cc index 30bb738c8e..0b268fd150 100644 --- a/src/vt/messaging/pending_send.cc +++ b/src/vt/messaging/pending_send.cc @@ -47,12 +47,6 @@ namespace vt { namespace messaging { -PendingSend::PendingSend( - MsgSharedPtr const& in_msg, ByteType const& in_msg_size) - : msg_(in_msg.toVirtual()), msg_size_(in_msg_size) { - produceMsg(); -} - PendingSend::PendingSend(EpochType ep, EpochActionType const& in_action) : epoch_action_{in_action}, epoch_produced_(ep) { if (epoch_produced_ != no_epoch) { diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h index f3e0627cab..43c196519d 100644 --- a/src/vt/objgroup/manager.impl.h +++ b/src/vt/objgroup/manager.impl.h @@ -295,7 +295,7 @@ ObjGroupManager::PendingSendType ObjGroupManager::reduce( auto const objgroup = proxy.getProxy(); auto r = theCollective()->getReducerObjGroup(objgroup); - r->template reduce(root, msg.get(), stamp); + return r->template reduce(root, msg.get(), stamp); } template diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index 39c7e15e10..770d80f766 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -1054,7 +1054,7 @@ messaging::PendingSend CollectionManager::broadcastMsgUntypedHandler( } template *f> -void CollectionManager::reduceMsgExpr( +messaging::PendingSend CollectionManager::reduceMsgExpr( CollectionProxyWrapType const& proxy, MsgT *const raw_msg, ReduceIdxFuncType expr_fn, ReduceStamp stamp, NodeType root @@ -1071,7 +1071,7 @@ void CollectionManager::reduceMsgExpr( auto const col_proxy = proxy.getProxy(); auto const cur_epoch = theMsg()->getEpochContextMsg(msg); - bufferOpOrExecute( + return bufferOpOrExecute( col_proxy, BufferTypeEnum::Reduce, static_cast( @@ -1115,7 +1115,7 @@ void CollectionManager::reduceMsgExpr( r = theCollective()->getReducerVrtProxy(col_proxy); } - auto ret_stamp = r->reduce(root_node, msg.get(), cur_stamp, num_elms); + auto ret_stamp = r->reduceImmediate(root_node, msg.get(), cur_stamp, num_elms); vt_debug_print( vrt_coll, node, @@ -1143,7 +1143,7 @@ void CollectionManager::reduceMsgExpr( } template *f> -void CollectionManager::reduceMsg( +messaging::PendingSend CollectionManager::reduceMsg( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceStamp stamp, NodeType root ) { @@ -1151,7 +1151,7 @@ void CollectionManager::reduceMsg( } template *f> -void CollectionManager::reduceMsg( +messaging::PendingSend CollectionManager::reduceMsg( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceStamp stamp, typename ColT::IndexType const& idx ) { @@ -1159,7 +1159,7 @@ void CollectionManager::reduceMsg( } template *f> -void CollectionManager::reduceMsgExpr( +messaging::PendingSend CollectionManager::reduceMsgExpr( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceIdxFuncType expr_fn, ReduceStamp stamp, typename ColT::IndexType const& idx diff --git a/tests/unit/termination/test_term_chaining.cc b/tests/unit/termination/test_term_chaining.cc index 5c3b79ca73..63c1bc05e2 100644 --- a/tests/unit/termination/test_term_chaining.cc +++ b/tests/unit/termination/test_term_chaining.cc @@ -157,26 +157,26 @@ struct TestTermChaining : TestParallelHarness { if (0 == node) { EpochType epoch1 = theTerm()->makeEpochRooted(); vt::theMsg()->pushEpoch(epoch1); - auto msg = makeSharedMessage(); + auto msg = makeMessage(); chain.add( - epoch1, theMsg()->sendMsg(1, msg)); + epoch1, theMsg()->sendMsg(1, msg.get())); vt::theMsg()->popEpoch(epoch1); vt::theTerm()->finishedEpoch(epoch1); } EpochType epoch2 = theTerm()->makeEpochCollective(); vt::theMsg()->pushEpoch(epoch2); - auto msg2 = makeSharedMessage(theContext()->getNode()); - chain.add(epoch2, theCollective()->reduce(0, msg2)); + auto msg2 = makeMessage(theContext()->getNode()); + chain.add(epoch2, theCollective()->global()->reduce(0, msg2.get())); vt::theMsg()->popEpoch(epoch2); vt::theTerm()->finishedEpoch(epoch2); // Broadcast from both nodes, bcast wont send to itself EpochType epoch3 = theTerm()->makeEpochRooted(); vt::theMsg()->pushEpoch(epoch3); - auto msg3 = makeSharedMessage(); + auto msg3 = makeMessage(); chain.add( - epoch3, theMsg()->broadcastMsg(msg3)); + epoch3, theMsg()->broadcastMsg(msg3.get())); vt::theMsg()->popEpoch(epoch3); vt::theTerm()->finishedEpoch(epoch3); @@ -188,15 +188,15 @@ struct TestTermChaining : TestParallelHarness { EpochType epoch2 = theTerm()->makeEpochRooted(); vt::theMsg()->pushEpoch(epoch2); - auto msg2 = makeSharedMessage(theContext()->getNode()); - chain.add(epoch2, theCollective()->reduce(0, msg2)); + auto msg2 = makeMessage(theContext()->getNode()); + chain.add(epoch2, theCollective()->global()->reduce(0, msg2.get())); vt::theMsg()->popEpoch(epoch2); vt::theTerm()->finishedEpoch(epoch2); EpochType epoch3 = theTerm()->makeEpochRooted(); vt::theMsg()->pushEpoch(epoch3); - auto msg3 = makeSharedMessage(); - chain.add(epoch3, theMsg()->broadcastMsg(msg3)); + auto msg3 = makeMessage(); + chain.add(epoch3, theMsg()->broadcastMsg(msg3.get())); vt::theMsg()->popEpoch(epoch3); vt::theTerm()->finishedEpoch(epoch3);