diff --git a/examples/collection/transpose.cc b/examples/collection/transpose.cc index 7feaab0e70..71a89272f7 100644 --- a/examples/collection/transpose.cc +++ b/examples/collection/transpose.cc @@ -177,9 +177,6 @@ struct Block : vt::Collection { auto proxy = this->getCollectionProxy(); auto proxy_msg = vt::makeMessage(proxy.getProxy()); vt::theMsg()->broadcastMsg(proxy_msg); - // Invoke it locally: broadcast sends to all other nodes - auto proxy_msg_local = vt::makeMessage(proxy.getProxy()); - SetupGroup()(proxy_msg_local.get()); } } @@ -314,7 +311,7 @@ static void solveGroupSetup(vt::NodeType this_node, vt::VirtualProxyType coll_pr vt::theGroup()->newGroupCollective( is_even_node, [=](vt::GroupType group_id){ - fmt::print("Group is created: id={:x}\n", group_id); + fmt::print("{}: Group is created: id={:x}\n", this_node, group_id); if (this_node == 1) { auto msg = vt::makeMessage(coll_proxy); vt::envelopeSetGroup(msg->env, group_id); diff --git a/src/vt/collective/barrier/barrier.cc b/src/vt/collective/barrier/barrier.cc index db9bc50693..e5b1a3fd95 100644 --- a/src/vt/collective/barrier/barrier.cc +++ b/src/vt/collective/barrier/barrier.cc @@ -235,7 +235,6 @@ void Barrier::barrierUp( "barrierDown: barrier={}\n", barrier ); theMsg()->broadcastMsg(msg); - barrierDown(is_named, is_wait, barrier); } } } diff --git a/src/vt/group/global/group_default.cc b/src/vt/group/global/group_default.cc index aaa9cfb757..86dd89d09c 100644 --- a/src/vt/group/global/group_default.cc +++ b/src/vt/group/global/group_default.cc @@ -150,7 +150,7 @@ namespace vt { namespace group { namespace global { /*static*/ EventType DefaultGroup::broadcast( MsgSharedPtr const& base, NodeType const& from, - MsgSizeType const& size, bool const is_root + MsgSizeType const& size, bool const is_root, bool* const deliver ) { // By default use the default_group_->spanning_tree_ auto const& msg = base.get(); @@ -159,7 +159,8 @@ namespace vt { namespace group { namespace global { auto const& num_children = default_group_->spanning_tree_->getNumChildren(); auto const& node = theContext()->getNode(); NodeType const& root_node = 0; - bool const& send_to_root = is_root && node != root_node; + auto const is_root_of_tree = node == root_node; + bool const& send_to_root = is_root && !is_root_of_tree; EventType event = no_event; vt_debug_print( @@ -168,7 +169,7 @@ namespace vt { namespace group { namespace global { print_ptr(base.get()), size, from, dest, print_bool(is_root) ); - if (num_children > 0 || send_to_root) { + if (is_root || ((num_children > 0) && !is_root_of_tree) || send_to_root) { auto const& send_tag = static_cast( messaging::MPITag::ActiveMsgTag ); @@ -191,6 +192,10 @@ namespace vt { namespace group { namespace global { }); } + if (is_root && envelopeGetDeliverBcast(msg->env)) { + *deliver = true; + } + // If not the root of the spanning tree, send to the root to propagate to // the rest of the tree if (send_to_root) { diff --git a/src/vt/group/global/group_default.h b/src/vt/group/global/group_default.h index 2e1a84cbcc..6a545827fd 100644 --- a/src/vt/group/global/group_default.h +++ b/src/vt/group/global/group_default.h @@ -76,7 +76,7 @@ struct DefaultGroup { public: static EventType broadcast( MsgSharedPtr const& base, NodeType const& from, - MsgSizeType const& size, bool const is_root + MsgSizeType const& size, bool const is_root, bool* const deliver ); private: diff --git a/src/vt/group/group_manager.cc b/src/vt/group/group_manager.cc index 5c0d890e17..53cac21896 100644 --- a/src/vt/group/group_manager.cc +++ b/src/vt/group/group_manager.cc @@ -269,7 +269,7 @@ void GroupManager::initializeLocalGroup( // Deliver the message normally if it's not a the root of a broadcast *deliver = !root; if (group == default_group) { - return global::DefaultGroup::broadcast(base,from,size,root); + return global::DefaultGroup::broadcast(base,from,size,root,deliver); } else { auto const& is_collective_group = GroupIDBuilder::isCollective(group); if (is_collective_group) { diff --git a/src/vt/messaging/active.cc b/src/vt/messaging/active.cc index ae9198200a..04f9ca0927 100644 --- a/src/vt/messaging/active.cc +++ b/src/vt/messaging/active.cc @@ -186,6 +186,7 @@ EventType ActiveMessenger::sendMsgBytesWithPut( auto const& is_term = envelopeIsTerm(msg->env); auto const& is_put = envelopeIsPut(msg->env); auto const& is_put_packed = envelopeIsPackedPutType(msg->env); + auto const& is_bcast = envelopeIsBcast(msg->env); if (!is_term || vt_check_enabled(print_term_msgs)) { vt_debug_print( @@ -195,6 +196,11 @@ EventType ActiveMessenger::sendMsgBytesWithPut( ); } + vtWarnIf( + !(dest != theContext()->getNode() || is_bcast), + "Destination {} should != this node" + ); + MsgSizeType new_msg_size = msg_size; if (is_put && !is_put_packed) { @@ -388,10 +394,6 @@ EventType ActiveMessenger::sendMsgBytes( ); } - vtWarnIf( - !(dest != theContext()->getNode() || is_bcast), - "Destination {} should != this node" - ); vtAbortIf( dest >= theContext()->getNumNodes() || dest < 0, "Invalid destination: {}" ); @@ -520,10 +522,6 @@ SendInfo ActiveMessenger::sendData( data_ptr, num_bytes, dest, tag, send_tag ); - vtWarnIf( - dest == theContext()->getNode(), - "Destination {} should != this node" - ); vtAbortIf( dest >= theContext()->getNumNodes() || dest < 0, "Invalid destination: {}" diff --git a/src/vt/messaging/active.h b/src/vt/messaging/active.h index 494b83b496..c03d902b89 100644 --- a/src/vt/messaging/active.h +++ b/src/vt/messaging/active.h @@ -666,6 +666,7 @@ struct ActiveMessenger : runtime::component::PollableComponent * * \param[in] msg the message to broadcast * \param[in] msg_size the size of the message to send + * \param[in] deliver_to_sender whether msg should be delivered to sender * \param[in] tag the tag to put on the message * * \return the \c PendingSend for the sent message @@ -674,6 +675,7 @@ struct ActiveMessenger : runtime::component::PollableComponent PendingSendType broadcastMsgSz( MsgPtrThief msg, ByteType msg_size, + bool deliver_to_sender = true, TagType tag = no_tag ); @@ -683,6 +685,7 @@ struct ActiveMessenger : runtime::component::PollableComponent * \note Takes ownership of the supplied message. * * \param[in] msg the message to broadcast + * \param[in] deliver_to_sender whether msg should be delivered to sender * \param[in] tag the tag to put on the message * * \return the \c PendingSend for the sent message @@ -690,6 +693,7 @@ struct ActiveMessenger : runtime::component::PollableComponent template * f> PendingSendType broadcastMsg( MsgPtrThief msg, + bool deliver_to_sender = true, TagType tag = no_tag ); @@ -819,6 +823,7 @@ struct ActiveMessenger : runtime::component::PollableComponent * \note Takes ownership of the supplied message. * * \param[in] msg the message to broadcast + * \param[in] deliver_to_sender whether msg should be delivered to sender * \param[in] tag the optional tag to put on the message * * \return the \c PendingSend for the broadcast @@ -826,6 +831,7 @@ struct ActiveMessenger : runtime::component::PollableComponent template PendingSendType broadcastMsg( MsgPtrThief msg, + bool deliver_to_sender = true, TagType tag = no_tag ); @@ -890,6 +896,7 @@ struct ActiveMessenger : runtime::component::PollableComponent * \note Takes ownership of the supplied message. * * \param[in] msg the message to broadcast + * \param[in] deliver_to_sender whether msg should be delivered to sender * \param[in] tag the optional tag to put on the message * * \return the \c PendingSend for the broadcast @@ -900,6 +907,7 @@ struct ActiveMessenger : runtime::component::PollableComponent > PendingSendType broadcastMsg( MsgPtrThief msg, + bool deliver_to_sender = true, TagType tag = no_tag ); @@ -1023,6 +1031,7 @@ struct ActiveMessenger : runtime::component::PollableComponent * * \param[in] han the handler to invoke * \param[in] msg the message to broadcast + * \param[in] deliver_to_sender whether msg should be delivered to sender * \param[in] tag the optional tag to put on the message * * \return the \c PendingSend for the send @@ -1031,6 +1040,7 @@ struct ActiveMessenger : runtime::component::PollableComponent PendingSendType broadcastMsg( HandlerType han, MsgPtrThief msg, + bool deliver_to_sender = true, TagType tag = no_tag ); diff --git a/src/vt/messaging/active.impl.h b/src/vt/messaging/active.impl.h index 15c1146ed2..d32c63c0af 100644 --- a/src/vt/messaging/active.impl.h +++ b/src/vt/messaging/active.impl.h @@ -146,7 +146,9 @@ ActiveMessenger::PendingSendType ActiveMessenger::sendMsgSerializableImpl( MsgT* msgp = msg.get(); if (dest == broadcast_dest) { - return SerializedMessenger::broadcastSerialMsg(msgp,han); + return SerializedMessenger::broadcastSerialMsg( + msgp, han, envelopeGetDeliverBcast(msgp->env) + ); } else { return SerializedMessenger::sendSerialMsg(dest,msgp,han); } @@ -172,20 +174,20 @@ ActiveMessenger::PendingSendType ActiveMessenger::sendMsgCopyableImpl( MsgT* rawMsg = msg.get(); bool is_term = envelopeIsTerm(rawMsg->env); + const bool is_bcast = dest == broadcast_dest; if (!is_term || vt_check_enabled(print_term_msgs)) { vt_debug_print( active, node, - dest == broadcast_dest + is_bcast ? "broadcastMsg of ptr={}, type={}\n" : "sendMsg of ptr={}, type={}\n", print_ptr(rawMsg), typeid(MsgT).name() ); } - if (dest == broadcast_dest) { + if (is_bcast) { dest = theContext()->getNode(); - setBroadcastType(rawMsg->env); } if (msg_size == msgsize_not_specified) { msg_size = sizeof(MsgT); @@ -240,10 +242,14 @@ template * f> ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgSz( MsgPtrThief msg, ByteType msg_size, + bool deliver_to_sender, TagType tag ) { auto const han = auto_registry::makeAutoHandler(); MsgSharedPtr msgptr = msg.msg_; + + setBroadcastType(msgptr->env, deliver_to_sender); + return sendMsgImpl( broadcast_dest, han, msgptr, msg_size, tag ); @@ -252,10 +258,14 @@ ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgSz( template * f> ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg( MsgPtrThief msg, + bool deliver_to_sender, TagType tag ) { auto const han = auto_registry::makeAutoHandler(); MsgSharedPtr msgptr = msg.msg_; + + setBroadcastType(msgptr->env, deliver_to_sender); + return sendMsgImpl( broadcast_dest, han, msgptr, msgsize_not_specified, tag ); @@ -310,10 +320,12 @@ ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgAuto( template ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg( MsgPtrThief msg, + bool deliver_to_sender, TagType tag ) { auto const han = auto_registry::makeAutoHandler(); MsgSharedPtr msgptr = msg.msg_; + setBroadcastType(msgptr->env, deliver_to_sender); return sendMsgImpl( broadcast_dest, han, msgptr, msgsize_not_specified, tag ); @@ -333,10 +345,12 @@ ActiveMessenger::PendingSendType ActiveMessenger::sendMsg( template ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg( MsgPtrThief msg, + bool deliver_to_sender, TagType tag ) { auto const han = auto_registry::makeAutoHandlerFunctor(); MsgSharedPtr msgptr = msg.msg_; + setBroadcastType(msgptr->env, deliver_to_sender); return sendMsgImpl( broadcast_dest, han, msgptr, msgsize_not_specified, tag ); @@ -410,9 +424,11 @@ template ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg( HandlerType han, MsgPtrThief msg, + bool deliver_to_sender, TagType tag ) { MsgSharedPtr msgptr = msg.msg_; + setBroadcastType(msgptr->env, deliver_to_sender); return sendMsgImpl( broadcast_dest, han, msgptr, msgsize_not_specified, tag ); diff --git a/src/vt/messaging/envelope/envelope_base.h b/src/vt/messaging/envelope/envelope_base.h index 157f95c273..9d1c1d4f7c 100644 --- a/src/vt/messaging/envelope/envelope_base.h +++ b/src/vt/messaging/envelope/envelope_base.h @@ -98,6 +98,10 @@ struct ActiveEnvelope { /// True iff the message is considered locked. /// If locked, changes to the envelope will result in failure. bool is_locked : 1; + + // Used only for broadcast to default group + // Determines whether message should also be sent to the sender + bool deliver_bcast_to_sender : 1; }; }} /* end namespace vt::messaging */ diff --git a/src/vt/messaging/envelope/envelope_get.h b/src/vt/messaging/envelope/envelope_get.h index 82f7cfe150..8bc3715074 100644 --- a/src/vt/messaging/envelope/envelope_get.h +++ b/src/vt/messaging/envelope/envelope_get.h @@ -75,6 +75,16 @@ inline HandlerType envelopeGetHandler(Env const& env); template inline NodeType envelopeGetDest(Env const& env); +/** + * \brief Check whether bcast should be delivered to sender + * + * \param[in] env the envelope + * + * \return whether to deliver msg to sender + */ +template +inline bool envelopeGetDeliverBcast(Env const& env); + /** * \brief Get the group on an envelope * diff --git a/src/vt/messaging/envelope/envelope_get.impl.h b/src/vt/messaging/envelope/envelope_get.impl.h index c091f880a9..a78d90152b 100644 --- a/src/vt/messaging/envelope/envelope_get.impl.h +++ b/src/vt/messaging/envelope/envelope_get.impl.h @@ -60,6 +60,11 @@ inline NodeType envelopeGetDest(Env const& env) { return reinterpret_cast(&env)->dest; } +template +inline bool envelopeGetDeliverBcast(Env const& env) { + return reinterpret_cast(&env)->deliver_bcast_to_sender; +} + template inline GroupType envelopeGetGroup(Env& env) { return reinterpret_cast(&env)->group; diff --git a/src/vt/messaging/envelope/envelope_set.h b/src/vt/messaging/envelope/envelope_set.h index 46db0872ec..01b29a4004 100644 --- a/src/vt/messaging/envelope/envelope_set.h +++ b/src/vt/messaging/envelope/envelope_set.h @@ -91,9 +91,10 @@ inline void setTermType(Env& env); * \brief Set broadcast bit (changes how \c dest is interpreted) \c EnvBroadcast * * \param[in,out] env the envelope + * \param[in] deliver_to_sender whether the message should be delivered to sender */ template -inline void setBroadcastType(Env& env); +inline void setBroadcastType(Env& env, bool deliver_to_sender = true); /** * \brief Set epoch bit \c EnvEpoch diff --git a/src/vt/messaging/envelope/envelope_set.impl.h b/src/vt/messaging/envelope/envelope_set.impl.h index 4931660ccd..603dd087e9 100644 --- a/src/vt/messaging/envelope/envelope_set.impl.h +++ b/src/vt/messaging/envelope/envelope_set.impl.h @@ -76,9 +76,10 @@ inline void setTermType(Env& env) { } template -inline void setBroadcastType(Env& env) { +inline void setBroadcastType(Env& env, bool deliver_to_sender) { vtAssert(not envelopeIsLocked(env), "Envelope locked."); reinterpret_cast(&env)->type |= 1 << eEnvType::EnvBroadcast; + reinterpret_cast(&env)->deliver_bcast_to_sender = deliver_to_sender; } template diff --git a/src/vt/messaging/envelope/envelope_setup.h b/src/vt/messaging/envelope/envelope_setup.h index b9d05ac252..669838d0bd 100644 --- a/src/vt/messaging/envelope/envelope_setup.h +++ b/src/vt/messaging/envelope/envelope_setup.h @@ -92,6 +92,15 @@ inline void envelopeInitEmpty(Envelope& env); template inline void envelopeInitCopy(Env& env, Env const& src_env); +/** + * \brief Copy broadcast related data + * + * \param[in,out] env the target envelope to copy to + * \param[in] env the original envelope to use as a copy + */ +template +inline void envelopeCopyBcastData(Env& env, Env const& src_env); + /** * \brief Initialize/validate an envelope that has been received. * diff --git a/src/vt/messaging/envelope/envelope_setup.impl.h b/src/vt/messaging/envelope/envelope_setup.impl.h index 1ed6346017..a1bc2b0752 100644 --- a/src/vt/messaging/envelope/envelope_setup.impl.h +++ b/src/vt/messaging/envelope/envelope_setup.impl.h @@ -90,6 +90,14 @@ inline void envelopeInitCopy(Env& env, Env const& src_env) { envelopeSetIsLocked(env, false); } +template +inline void envelopeCopyBcastData(Env& env, Env const& src_env) { + envelopeSetIsLocked(env, false); + envelopeSetDest(env, envelopeGetDest(src_env)); + setBroadcastType(env); + envelopeSetIsLocked(env, true); +} + template inline void envelopeInitRecv(Env& env) { // Reset the local ref-count. The sender ref-count is not relevant. diff --git a/src/vt/messaging/envelope/envelope_test.h b/src/vt/messaging/envelope/envelope_test.h index d497a3ba54..13a78a1944 100644 --- a/src/vt/messaging/envelope/envelope_test.h +++ b/src/vt/messaging/envelope/envelope_test.h @@ -46,6 +46,7 @@ #define INCLUDED_MESSAGING_ENVELOPE_ENVELOPE_TEST_H #include "vt/config.h" +#include "vt/context/context.h" #include "vt/messaging/envelope/envelope_type.h" #include "vt/messaging/envelope/envelope_base.h" diff --git a/src/vt/objgroup/manager.static.h b/src/vt/objgroup/manager.static.h index db5e2151ed..8a42df2a6c 100644 --- a/src/vt/objgroup/manager.static.h +++ b/src/vt/objgroup/manager.static.h @@ -83,13 +83,7 @@ void invoke(messaging::MsgPtrThief msg, HandlerType han, NodeType dest_nod template void broadcast(MsgSharedPtr msg, HandlerType han) { - // Get the current epoch for the message - auto const cur_epoch = theMsg()->setupEpochMsg(msg); - // Broadcast the message - auto msg_hold = promoteMsg(msg.get()); // for scheduling - theMsg()->broadcastMsg(han, msg, no_tag); - // Schedule delivery on this node for the objgroup - scheduleMsg(msg_hold.template toVirtual(), han, cur_epoch); + theMsg()->broadcastMsg(han, msg); } }} /* end namespace vt::objgroup */ diff --git a/src/vt/objgroup/proxy/proxy_objgroup.impl.h b/src/vt/objgroup/proxy/proxy_objgroup.impl.h index 411a6155d7..c01b51a6b8 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.impl.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.impl.h @@ -235,7 +235,7 @@ messaging::PendingSend Proxy::broadcast(Args&&... args) const { template * f> messaging::PendingSend Proxy::broadcastMsg(messaging::MsgPtrThief msg, TagType tag) const { - return theMsg()->broadcastMsg(msg, tag); + return theMsg()->broadcastMsg(msg, true, tag); } template diff --git a/src/vt/pipe/callback/callback_handler_bcast.h b/src/vt/pipe/callback/callback_handler_bcast.h index 28d95fc33d..c2506cea83 100644 --- a/src/vt/pipe/callback/callback_handler_bcast.h +++ b/src/vt/pipe/callback/callback_handler_bcast.h @@ -61,27 +61,10 @@ struct CallbackBcast : CallbackBase> { using SignalDataType = typename SignalType::DataType; using MessageType = MsgT; - explicit CallbackBcast(bool const in_include_root = false) - : include_root_(in_include_root) - { } - - template - void serialize(SerializerT& s) { - CallbackBase::serializer(s); - s | include_root_; - } - private: void trigger_(SignalDataType* data) override { theMsg()->broadcastMsg(data); - if (include_root_) { - auto nmsg = makeMessage(*data); - f(nmsg.get()); - } } - -private: - bool include_root_ = false; }; }}} /* end namespace vt::pipe::callback */ diff --git a/src/vt/pipe/callback/cb_union/cb_raw.h b/src/vt/pipe/callback/cb_union/cb_raw.h index 018e99a3fb..28613264ce 100644 --- a/src/vt/pipe/callback/cb_union/cb_raw.h +++ b/src/vt/pipe/callback/cb_union/cb_raw.h @@ -73,8 +73,8 @@ struct SendMsgCB : CallbackSendTypeless { struct BcastMsgCB : CallbackBcastTypeless { BcastMsgCB() = default; BcastMsgCB( - HandlerType const in_handler, bool const& in_include - ) : CallbackBcastTypeless(in_handler, in_include) + HandlerType const in_handler + ) : CallbackBcastTypeless(in_handler) { } }; diff --git a/src/vt/pipe/callback/cb_union/cb_raw_base.cc b/src/vt/pipe/callback/cb_union/cb_raw_base.cc index a128e81105..4bae290a30 100644 --- a/src/vt/pipe/callback/cb_union/cb_raw_base.cc +++ b/src/vt/pipe/callback/cb_union/cb_raw_base.cc @@ -55,9 +55,8 @@ CallbackRawBaseSingle::CallbackRawBaseSingle( { } CallbackRawBaseSingle::CallbackRawBaseSingle( - RawBcastMsgTagType, PipeType const& in_pipe, HandlerType const in_handler, - bool const& in_inc -) : pipe_(in_pipe), cb_(BcastMsgCB{in_handler,in_inc}) + RawBcastMsgTagType, PipeType const& in_pipe, HandlerType const in_handler +) : pipe_(in_pipe), cb_(BcastMsgCB{in_handler}) { } CallbackRawBaseSingle::CallbackRawBaseSingle( diff --git a/src/vt/pipe/callback/cb_union/cb_raw_base.h b/src/vt/pipe/callback/cb_union/cb_raw_base.h index 310da52ca1..c1f8790f4e 100644 --- a/src/vt/pipe/callback/cb_union/cb_raw_base.h +++ b/src/vt/pipe/callback/cb_union/cb_raw_base.h @@ -89,8 +89,7 @@ struct CallbackRawBaseSingle { NodeType const& in_node ); CallbackRawBaseSingle( - RawBcastMsgTagType, PipeType const& in_pipe, HandlerType const in_handler, - bool const& in_inc + RawBcastMsgTagType, PipeType const& in_pipe, HandlerType const in_handler ); CallbackRawBaseSingle(RawAnonTagType, PipeType const& in_pipe); CallbackRawBaseSingle(RawSendColMsgTagType, PipeType const& in_pipe); @@ -172,9 +171,8 @@ struct CallbackTyped : CallbackRawBaseSingle { ) : CallbackRawBaseSingle(RawSendMsgTag,in_pipe,in_handler,in_node) { } CallbackTyped( - RawBcastMsgTagType, PipeType const& in_pipe, HandlerType const in_handler, - bool const& in_inc - ) : CallbackRawBaseSingle(RawBcastMsgTag,in_pipe,in_handler,in_inc) + RawBcastMsgTagType, PipeType const& in_pipe, HandlerType const in_handler + ) : CallbackRawBaseSingle(RawBcastMsgTag,in_pipe,in_handler) { } CallbackTyped(RawAnonTagType, PipeType const& in_pipe) : CallbackRawBaseSingle(RawAnonTag,in_pipe) diff --git a/src/vt/pipe/callback/handler_bcast/callback_bcast.impl.h b/src/vt/pipe/callback/handler_bcast/callback_bcast.impl.h index 37af3d4e2f..c3ec0ed985 100644 --- a/src/vt/pipe/callback/handler_bcast/callback_bcast.impl.h +++ b/src/vt/pipe/callback/handler_bcast/callback_bcast.impl.h @@ -93,9 +93,6 @@ CallbackBcast::triggerDispatch(SignalDataType* data, PipeType const& pid) ); auto msg = makeMessage(pid); theMsg()->broadcastMsg(handler_, msg); - if (include_sender_) { - runnable::RunnableVoid::run(handler_,this_node); - } } template @@ -109,13 +106,6 @@ CallbackBcast::triggerDispatch(SignalDataType* data, PipeType const& pid) this_node, include_sender_ ); theMsg()->broadcastMsg(handler_, data); - auto msg_group = envelopeGetGroup(data->env); - bool const is_default = msg_group == default_group; - if (include_sender_ and is_default) { - auto nmsg = makeMessage(*data); - auto short_msg = nmsg.template to.get(); - runnable::Runnable::run(handler_,nullptr,short_msg,this_node); - } } }}} /* end namespace vt::pipe::callback */ diff --git a/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.cc b/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.cc index 471881ec80..ced55f77dd 100644 --- a/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.cc +++ b/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.cc @@ -55,23 +55,17 @@ namespace vt { namespace pipe { namespace callback { CallbackBcastTypeless::CallbackBcastTypeless( - HandlerType const in_handler, bool const& in_include -) : handler_(in_handler), include_sender_(in_include) + HandlerType const in_handler +) : handler_(in_handler) { } void CallbackBcastTypeless::triggerVoid(PipeType const& pipe) { auto const& this_node = theContext()->getNode(); vt_debug_print( - pipe, node, - "CallbackBcast: (void) trigger_: pipe={:x}, this_node={}, " - "include_sender_={}\n", - pipe, this_node, include_sender_ - ); + pipe, node, "CallbackBcast: (void) trigger_: pipe={:x}, this_node={}\n", + pipe, this_node); auto msg = makeMessage(pipe); theMsg()->broadcastMsg(handler_, msg); - if (include_sender_) { - runnable::RunnableVoid::run(handler_,this_node); - } } }}} /* end namespace vt::pipe::callback */ diff --git a/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.h b/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.h index ea47d47cda..0ccaa46b1f 100644 --- a/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.h +++ b/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.h @@ -60,19 +60,16 @@ struct CallbackBcastTypeless : CallbackBaseTL { CallbackBcastTypeless& operator=(CallbackBcastTypeless const&) = default; CallbackBcastTypeless( - HandlerType const in_handler, bool const& in_include + HandlerType const in_handler ); HandlerType getHandler() const { return handler_; } - bool getIncSender() const { return include_sender_; } template void serialize(SerializerT& s); bool operator==(CallbackBcastTypeless const& other) const { - return - other.include_sender_ == include_sender_ && - other.handler_ == handler_; + return other.handler_ == handler_; } public: @@ -82,7 +79,6 @@ struct CallbackBcastTypeless : CallbackBaseTL { private: HandlerType handler_ = uninitialized_handler; - bool include_sender_ = false; }; }}} /* end namespace vt::pipe::callback */ diff --git a/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.impl.h b/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.impl.h index 36d4942b9b..4ca895a5df 100644 --- a/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.impl.h +++ b/src/vt/pipe/callback/handler_bcast/callback_bcast_tl.impl.h @@ -58,7 +58,6 @@ namespace vt { namespace pipe { namespace callback { template void CallbackBcastTypeless::serialize(SerializerT& s) { - s | include_sender_; s | handler_; } @@ -68,19 +67,11 @@ void CallbackBcastTypeless::trigger(MsgT* msg, PipeType const& pipe) { vt_debug_print( pipe, node, - "CallbackBcast: trigger_: pipe={:x}, this_node={}, include_sender_={}\n", - pipe, this_node, include_sender_ + "CallbackBcast: trigger_: pipe={:x}, this_node={}\n", + pipe, this_node ); - auto pmsg = promoteMsg(msg); - theMsg()->broadcastMsg(handler_, pmsg); - - auto msg_group = envelopeGetGroup(msg->env); - bool const is_default = msg_group == default_group; - if (include_sender_ and is_default) { - auto nmsg = makeMessage(*msg); // create copy (?) - runnable::Runnable::run(handler_, nullptr, nmsg.get(), this_node); - } + theMsg()->broadcastMsg(handler_, msg); } }}} /* end namespace vt::pipe::callback */ diff --git a/src/vt/pipe/pipe_manager.impl.h b/src/vt/pipe/pipe_manager.impl.h index 4c877e15c7..7b5c99ea68 100644 --- a/src/vt/pipe/pipe_manager.impl.h +++ b/src/vt/pipe/pipe_manager.impl.h @@ -128,17 +128,17 @@ Callback PipeManager::makeSend(objgroup::proxy::ProxyElm proxy) { template * f> Callback PipeManager::makeBcast() { - return makeCallbackSingleBcast(true); + return makeCallbackSingleBcast(); } template Callback PipeManager::makeBcast() { - return makeCallbackFunctorBcast(true); + return makeCallbackFunctorBcast(); } template Callback PipeManager::makeBcast() { - return makeCallbackFunctorBcastVoid(true); + return makeCallbackFunctorBcastVoid(); } template * f> diff --git a/src/vt/pipe/pipe_manager_tl.h b/src/vt/pipe/pipe_manager_tl.h index a2039b8fde..08fdba3b89 100644 --- a/src/vt/pipe/pipe_manager_tl.h +++ b/src/vt/pipe/pipe_manager_tl.h @@ -98,7 +98,7 @@ struct PipeManagerTL : virtual PipeManagerBase { CbkT makeCallbackSingleSend(NodeType const& node); template * f, typename CbkT = DefType> - CbkT makeCallbackSingleBcast(bool const& inc); + CbkT makeCallbackSingleBcast(); // Single active message functor-handler template < @@ -113,14 +113,14 @@ struct PipeManagerTL : virtual PipeManagerBase { typename T = typename util::FunctorExtractor::MessageType, typename CbkT = DefType > - CbkT makeCallbackFunctorBcast(bool const& inc); + CbkT makeCallbackFunctorBcast(); // Single active message functor-handler void param template > CbkT makeCallbackFunctorSendVoid(NodeType const& node); template > - CbkT makeCallbackFunctorBcastVoid(bool const& inc); + CbkT makeCallbackFunctorBcastVoid(); // Single active message anon func-handler template > @@ -188,7 +188,7 @@ struct PipeManagerTL : virtual PipeManagerBase { void addListener(CbkT const& cb, NodeType const& node); template * f, typename CbkT = DefType> - void addListenerBcast(CbkT const& cb, bool const& inc); + void addListenerBcast(CbkT const& cb); template < typename FunctorT, @@ -205,7 +205,7 @@ struct PipeManagerTL : virtual PipeManagerBase { typename T = typename util::FunctorExtractor::MessageType, typename CbkT = DefType > - void addListenerFunctorBcast(CbkT const& cb, bool const& inc); + void addListenerFunctorBcast(CbkT const& cb); }; }} /* end namespace vt::pipe */ diff --git a/src/vt/pipe/pipe_manager_tl.impl.h b/src/vt/pipe/pipe_manager_tl.impl.h index 189dd99bbc..c0c667688e 100644 --- a/src/vt/pipe/pipe_manager_tl.impl.h +++ b/src/vt/pipe/pipe_manager_tl.impl.h @@ -95,10 +95,10 @@ void PipeManagerTL::addListener(CallbackT const& cb, NodeType const& node) { } template * f, typename CallbackT> -void PipeManagerTL::addListenerBcast(CallbackT const& cb, bool const& inc) { +void PipeManagerTL::addListenerBcast(CallbackT const& cb) { auto const& han = auto_registry::makeAutoHandler(); addListenerAny( - cb.getPipe(), std::make_unique>(han,inc) + cb.getPipe(), std::make_unique>(han) ); } @@ -127,12 +127,12 @@ void PipeManagerTL::addListenerFunctorVoid( template void PipeManagerTL::addListenerFunctorBcast( - CallbackT const& cb, bool const& inc + CallbackT const& cb ) { using MsgT = typename util::FunctorExtractor::MessageType; auto const& han = auto_registry::makeAutoHandlerFunctor(); addListenerAny( - cb.getPipe(), std::make_unique>(han,inc) + cb.getPipe(), std::make_unique>(han) ); } @@ -297,11 +297,11 @@ CallbackT template * f, typename CallbackT> CallbackT -PipeManagerTL::makeCallbackSingleBcast(bool const& inc) { +PipeManagerTL::makeCallbackSingleBcast() { auto const& new_pipe_id = makePipeID(true,false); auto const& handler = auto_registry::makeAutoHandler(); auto cb = CallbackT( - callback::cbunion::RawBcastMsgTag,new_pipe_id,handler,inc + callback::cbunion::RawBcastMsgTag,new_pipe_id,handler ); return cb; } @@ -392,13 +392,13 @@ PipeManagerTL::makeCallbackFunctorSend(NodeType const& send_to_node) { template CallbackT -PipeManagerTL::makeCallbackFunctorBcast(bool const& inc) { +PipeManagerTL::makeCallbackFunctorBcast() { using MsgT = typename util::FunctorExtractor::MessageType; auto const& new_pipe_id = makePipeID(true,false); auto const& handler = auto_registry::makeAutoHandlerFunctor(); auto cb = CallbackT( - callback::cbunion::RawBcastMsgTag,new_pipe_id,handler,inc + callback::cbunion::RawBcastMsgTag,new_pipe_id,handler ); return cb; } @@ -416,11 +416,11 @@ PipeManagerTL::makeCallbackFunctorSendVoid(NodeType const& send_to_node) { template CallbackT -PipeManagerTL::makeCallbackFunctorBcastVoid(bool const& inc) { +PipeManagerTL::makeCallbackFunctorBcastVoid() { auto const& new_pipe_id = makePipeID(true,false); auto const& handler = auto_registry::makeAutoHandlerFunctor(); auto cb = CallbackT( - callback::cbunion::RawBcastMsgTag,new_pipe_id,handler,inc + callback::cbunion::RawBcastMsgTag,new_pipe_id,handler ); return cb; } diff --git a/src/vt/serialization/messaging/serialized_messenger.h b/src/vt/serialization/messaging/serialized_messenger.h index a58b785c3f..c6ed39dd1f 100644 --- a/src/vt/serialization/messaging/serialized_messenger.h +++ b/src/vt/serialization/messaging/serialized_messenger.h @@ -94,7 +94,7 @@ struct SerializedMessenger { template static messaging::PendingSend broadcastSerialMsg( - MsgT* msg, HandlerType han + MsgT* msg, HandlerType han, bool deliver_to_sender = true ); template diff --git a/src/vt/serialization/messaging/serialized_messenger.impl.h b/src/vt/serialization/messaging/serialized_messenger.impl.h index b3d52ba9fb..5fa29ef843 100644 --- a/src/vt/serialization/messaging/serialized_messenger.impl.h +++ b/src/vt/serialization/messaging/serialized_messenger.impl.h @@ -156,6 +156,12 @@ template auto msg_data = sys_msg->payload.data(); auto user_msg = deserializeFullMessage(msg_data); + // Keep bcast related data in user_msg since it's sometimes + // needed in the handler + if (envelopeIsBcast(sys_msg->env)) { + envelopeCopyBcastData(user_msg->env, sys_msg->env); + } + vt_debug_print( serial_msg, node, "payloadMsgHandler: group={:x}, msg={}, handler={}, bytes={}, " @@ -193,7 +199,7 @@ template template /*static*/ messaging::PendingSend SerializedMessenger::broadcastSerialMsg( - MsgT* msg_ptr, HandlerType han + MsgT* msg_ptr, HandlerType han, bool deliver_to_sender ) { using PayloadMsg = SerialEagerPayloadMsg; @@ -250,7 +256,7 @@ template ); theMsg()->markAsSerialMsgMessage(payload_msg); - return theMsg()->broadcastMsg(payload_msg); + return theMsg()->broadcastMsg(payload_msg, deliver_to_sender); } else { auto const& total_size = ptr_size + sys_size; diff --git a/src/vt/termination/termination.cc b/src/vt/termination/termination.cc index 899ca16875..ba2f217287 100644 --- a/src/vt/termination/termination.cc +++ b/src/vt/termination/termination.cc @@ -485,7 +485,7 @@ bool TerminationDetector::propagateEpoch(TermStateType& state) { if (is_term) { auto msg = makeMessage(state.getEpoch()); theMsg()->markAsTermMessage(msg); - theMsg()->broadcastMsg(msg); + theMsg()->broadcastMsg(msg, false); state.setTerminated(); @@ -514,7 +514,7 @@ bool TerminationDetector::propagateEpoch(TermStateType& state) { auto msg = makeMessage(state.getEpoch(), state.getCurWave()); theMsg()->markAsTermMessage(msg); - theMsg()->broadcastMsg(msg); + theMsg()->broadcastMsg(msg, false); } } @@ -595,7 +595,7 @@ void TerminationDetector::countsConstant(TermStateType& state) { if (not theConfig()->vt_no_detect_hang) { auto msg = makeMessage(); theMsg()->markAsTermMessage(msg.get()); - theMsg()->broadcastMsg(msg); + theMsg()->broadcastMsg(msg, false); hangCheckHandler(nullptr); } } @@ -612,7 +612,7 @@ void TerminationDetector::startEpochGraphBuild() { if (theConfig()->vt_epoch_graph_on_hang) { auto msg = makeMessage(); theMsg()->markAsTermMessage(msg.get()); - theMsg()->broadcastMsg(msg); + theMsg()->broadcastMsg(msg, false); buildLocalGraphHandler(nullptr); } } @@ -991,7 +991,7 @@ EpochType TerminationDetector::makeEpochRootedWave( */ auto msg = makeMessage(epoch); theMsg()->markAsTermMessage(msg); - theMsg()->broadcastMsg(msg); + theMsg()->broadcastMsg(msg, false); /* * Setup the new rooted epoch locally on the root node (this node) diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index 74cb18f406..5f4c791fba 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -618,10 +618,12 @@ template auto const& proxy = msg->proxy; theCollection()->constructed_.insert(proxy); theCollection()->addToState(proxy, BufferReleaseEnum::AfterFullyConstructed); + vt_debug_print( vrt_coll, node, "addToState: proxy={:x}, AfterCons\n", proxy ); + theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Broadcast); theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Send); theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Reduce); @@ -1035,10 +1037,6 @@ messaging::PendingSend CollectionManager::broadcastFromRoot(MsgT* raw_msg) { msg ); - if (!send_group) { - collectionBcastHandler(msg_hold.get()); - } - theMsg()->popEpoch(cur_epoch); return ret; @@ -2171,6 +2169,7 @@ template vrt_coll, node, "addToState: proxy={:x}, AfterMeta\n", proxy ); + theCollection()->addToState(proxy, BufferReleaseEnum::AfterMetaDataKnown); theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Send); theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Broadcast); @@ -2294,7 +2293,7 @@ CollectionManager::constructMap( ); theMsg()->broadcastMsg>( - create_msg + create_msg, false ); auto create_msg_local = makeMessage( @@ -2461,7 +2460,7 @@ void CollectionManager::finishedInsertEpoch( theMsg()->markAsCollectionMessage(msg); theMsg()->broadcastMsg< UpdateInsertMsg,updateInsertEpochHandler - >(msg); + >(msg, false); /* * Start building the a new group for broadcasts and reductions over the @@ -2972,10 +2971,7 @@ void CollectionManager::destroy( auto msg = makeMessage(proxy, this_node); theMsg()->markAsCollectionMessage(msg); - auto msg_hold = promoteMsg(msg.get()); // keep after bcast theMsg()->broadcastMsg(msg); - - DestroyHandlers::destroyNow(msg_hold.get()); } template @@ -3284,6 +3280,7 @@ messaging::PendingSend CollectionManager::bufferOpOrExecute( VirtualProxyType proxy, BufferTypeEnum type, BufferReleaseEnum release, EpochType epoch, ActionPendingType action ) { + if (checkReady(proxy, release)) { theMsg()->pushEpoch(epoch); auto ps = action(); diff --git a/tests/unit/active/test_active_bcast_put.cc b/tests/unit/active/test_active_bcast_put.cc index aec1461539..ff3da48a0b 100644 --- a/tests/unit/active/test_active_bcast_put.cc +++ b/tests/unit/active/test_active_bcast_put.cc @@ -131,9 +131,7 @@ TEST_P(TestActiveBroadcastPut, test_type_safe_active_fn_bcast2) { } }); - if (my_node != root) { - ASSERT_TRUE(handler_count == num_msg_sent); - } + ASSERT_TRUE(handler_count == num_msg_sent); } // Spin here so test_vec does not go out of scope before the send completes diff --git a/tests/unit/active/test_active_broadcast.cc b/tests/unit/active/test_active_broadcast.cc index 09962eb9f5..358479d12a 100644 --- a/tests/unit/active/test_active_broadcast.cc +++ b/tests/unit/active/test_active_broadcast.cc @@ -109,9 +109,7 @@ TEST_P(TestActiveBroadcast, test_type_safe_active_fn_bcast2) { } }); - if (my_node != root) { - ASSERT_TRUE(handler_count == num_msg_sent); - } + ASSERT_TRUE(handler_count == num_msg_sent); } } diff --git a/tests/unit/location/test_location.cc b/tests/unit/location/test_location.cc index 1c4a4c3d06..100c83a8be 100644 --- a/tests/unit/location/test_location.cc +++ b/tests/unit/location/test_location.cc @@ -353,14 +353,14 @@ TYPED_TEST_P(TestLocationRoute, test_route_entity) { auto msg = vt::makeMessage(entity, my_node, is_long); vt::theMsg()->broadcastMsg>(msg); - while (msg_count < nb_nodes - 1) { vt::runScheduler(); } + vt::theSched()->runSchedulerWhile([&msg_count, nb_nodes]{ return msg_count < nb_nodes; }); vt_debug_print( location, node, "TestLocationRoute: all messages have been arrived\n" ); - EXPECT_EQ(msg_count, nb_nodes - 1); + EXPECT_EQ(msg_count, nb_nodes); } } diff --git a/tests/unit/memory/test_memory_lifetime.cc b/tests/unit/memory/test_memory_lifetime.cc index 06aec3d23a..e068ab8e03 100644 --- a/tests/unit/memory/test_memory_lifetime.cc +++ b/tests/unit/memory/test_memory_lifetime.cc @@ -149,7 +149,7 @@ TEST_F(TestMemoryLifetime, test_active_bcast_serial_lifetime) { }); EXPECT_EQ(SerialTrackMsg::alloc_count, 0); - EXPECT_EQ(local_count, num_msgs_sent*(num_nodes-1)); + EXPECT_EQ(local_count, num_msgs_sent*num_nodes); } } @@ -210,7 +210,7 @@ TEST_F(TestMemoryLifetime, test_active_bcast_normal_lifetime_msgptr) { } theTerm()->addAction([=]{ - EXPECT_EQ(local_count, num_msgs_sent*(num_nodes-1)); + EXPECT_EQ(local_count, num_msgs_sent*num_nodes); }); } } diff --git a/tests/unit/termination/test_term_chaining.cc b/tests/unit/termination/test_term_chaining.cc index 0c3f3eeb36..f70531b475 100644 --- a/tests/unit/termination/test_term_chaining.cc +++ b/tests/unit/termination/test_term_chaining.cc @@ -119,12 +119,19 @@ struct TestTermChaining : TestParallelHarness { } static void test_handler_bcast(TestMsg* msg) { + static auto visited = 0; + if (theContext()->getNode() == 0) { EXPECT_EQ(handler_count, 2); } else { EXPECT_EQ(handler_count, 12); } - handler_count = 3; + + ++visited; + + if (visited == 2) { + handler_count = 3; + } } static void start_chain() { @@ -170,7 +177,6 @@ struct TestTermChaining : TestParallelHarness { vt::theMsg()->popEpoch(epoch2); vt::theTerm()->finishedEpoch(epoch2); - // Broadcast from both nodes, bcast wont send to itself EpochType epoch3 = theTerm()->makeEpochRooted(); vt::theMsg()->pushEpoch(epoch3); auto msg3 = makeMessage();