Skip to content

Commit

Permalink
Merge pull request #1225 from DARMA-tasking/883-make-bcast-deliver-to…
Browse files Browse the repository at this point in the history
…-sender

883: Make broadcast deliver to sender (for default group)
  • Loading branch information
lifflander authored Feb 23, 2021
2 parents 164d001 + 73b7e8c commit b5edc11
Show file tree
Hide file tree
Showing 38 changed files with 158 additions and 144 deletions.
5 changes: 1 addition & 4 deletions examples/collection/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,6 @@ struct Block : vt::Collection<Block, vt::Index1D> {
auto proxy = this->getCollectionProxy();
auto proxy_msg = vt::makeMessage<ProxyMsg>(proxy.getProxy());
vt::theMsg()->broadcastMsg<SetupGroup,ProxyMsg>(proxy_msg);
// Invoke it locally: broadcast sends to all other nodes
auto proxy_msg_local = vt::makeMessage<ProxyMsg>(proxy.getProxy());
SetupGroup()(proxy_msg_local.get());
}
}

Expand Down Expand Up @@ -314,7 +311,7 @@ static void solveGroupSetup(vt::NodeType this_node, vt::VirtualProxyType coll_pr

vt::theGroup()->newGroupCollective(
is_even_node, [=](vt::GroupType group_id){
fmt::print("Group is created: id={:x}\n", group_id);
fmt::print("{}: Group is created: id={:x}\n", this_node, group_id);
if (this_node == 1) {
auto msg = vt::makeMessage<SubSolveMsg>(coll_proxy);
vt::envelopeSetGroup(msg->env, group_id);
Expand Down
1 change: 0 additions & 1 deletion src/vt/collective/barrier/barrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ void Barrier::barrierUp(
"barrierDown: barrier={}\n", barrier
);
theMsg()->broadcastMsg<BarrierMsg, barrierDown>(msg);
barrierDown(is_named, is_wait, barrier);
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions src/vt/group/global/group_default.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ namespace vt { namespace group { namespace global {

/*static*/ EventType DefaultGroup::broadcast(
MsgSharedPtr<BaseMsgType> const& base, NodeType const& from,
MsgSizeType const& size, bool const is_root
MsgSizeType const& size, bool const is_root, bool* const deliver
) {
// By default use the default_group_->spanning_tree_
auto const& msg = base.get();
Expand All @@ -159,7 +159,8 @@ namespace vt { namespace group { namespace global {
auto const& num_children = default_group_->spanning_tree_->getNumChildren();
auto const& node = theContext()->getNode();
NodeType const& root_node = 0;
bool const& send_to_root = is_root && node != root_node;
auto const is_root_of_tree = node == root_node;
bool const& send_to_root = is_root && !is_root_of_tree;
EventType event = no_event;

vt_debug_print(
Expand All @@ -168,7 +169,7 @@ namespace vt { namespace group { namespace global {
print_ptr(base.get()), size, from, dest, print_bool(is_root)
);

if (num_children > 0 || send_to_root) {
if (is_root || ((num_children > 0) && !is_root_of_tree) || send_to_root) {
auto const& send_tag = static_cast<messaging::MPI_TagType>(
messaging::MPITag::ActiveMsgTag
);
Expand All @@ -191,6 +192,10 @@ namespace vt { namespace group { namespace global {
});
}

if (is_root && envelopeGetDeliverBcast(msg->env)) {
*deliver = true;
}

// If not the root of the spanning tree, send to the root to propagate to
// the rest of the tree
if (send_to_root) {
Expand Down
2 changes: 1 addition & 1 deletion src/vt/group/global/group_default.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct DefaultGroup {
public:
static EventType broadcast(
MsgSharedPtr<BaseMsgType> const& base, NodeType const& from,
MsgSizeType const& size, bool const is_root
MsgSizeType const& size, bool const is_root, bool* const deliver
);

private:
Expand Down
2 changes: 1 addition & 1 deletion src/vt/group/group_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void GroupManager::initializeLocalGroup(
// Deliver the message normally if it's not a the root of a broadcast
*deliver = !root;
if (group == default_group) {
return global::DefaultGroup::broadcast(base,from,size,root);
return global::DefaultGroup::broadcast(base,from,size,root,deliver);
} else {
auto const& is_collective_group = GroupIDBuilder::isCollective(group);
if (is_collective_group) {
Expand Down
14 changes: 6 additions & 8 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ EventType ActiveMessenger::sendMsgBytesWithPut(
auto const& is_term = envelopeIsTerm(msg->env);
auto const& is_put = envelopeIsPut(msg->env);
auto const& is_put_packed = envelopeIsPackedPutType(msg->env);
auto const& is_bcast = envelopeIsBcast(msg->env);

if (!is_term || vt_check_enabled(print_term_msgs)) {
vt_debug_print(
Expand All @@ -195,6 +196,11 @@ EventType ActiveMessenger::sendMsgBytesWithPut(
);
}

vtWarnIf(
!(dest != theContext()->getNode() || is_bcast),
"Destination {} should != this node"
);

MsgSizeType new_msg_size = msg_size;

if (is_put && !is_put_packed) {
Expand Down Expand Up @@ -388,10 +394,6 @@ EventType ActiveMessenger::sendMsgBytes(
);
}

vtWarnIf(
!(dest != theContext()->getNode() || is_bcast),
"Destination {} should != this node"
);
vtAbortIf(
dest >= theContext()->getNumNodes() || dest < 0, "Invalid destination: {}"
);
Expand Down Expand Up @@ -520,10 +522,6 @@ SendInfo ActiveMessenger::sendData(
data_ptr, num_bytes, dest, tag, send_tag
);

vtWarnIf(
dest == theContext()->getNode(),
"Destination {} should != this node"
);
vtAbortIf(
dest >= theContext()->getNumNodes() || dest < 0,
"Invalid destination: {}"
Expand Down
10 changes: 10 additions & 0 deletions src/vt/messaging/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
*
* \param[in] msg the message to broadcast
* \param[in] msg_size the size of the message to send
* \param[in] deliver_to_sender whether msg should be delivered to sender
* \param[in] tag the tag to put on the message
*
* \return the \c PendingSend for the sent message
Expand All @@ -674,6 +675,7 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
PendingSendType broadcastMsgSz(
MsgPtrThief<MsgT> msg,
ByteType msg_size,
bool deliver_to_sender = true,
TagType tag = no_tag
);

Expand All @@ -683,13 +685,15 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
* \note Takes ownership of the supplied message.
*
* \param[in] msg the message to broadcast
* \param[in] deliver_to_sender whether msg should be delivered to sender
* \param[in] tag the tag to put on the message
*
* \return the \c PendingSend for the sent message
*/
template <typename MsgT, ActiveTypedFnType<MsgT>* f>
PendingSendType broadcastMsg(
MsgPtrThief<MsgT> msg,
bool deliver_to_sender = true,
TagType tag = no_tag
);

Expand Down Expand Up @@ -819,13 +823,15 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
* \note Takes ownership of the supplied message.
*
* \param[in] msg the message to broadcast
* \param[in] deliver_to_sender whether msg should be delivered to sender
* \param[in] tag the optional tag to put on the message
*
* \return the \c PendingSend for the broadcast
*/
template <ActiveFnType* f, typename MsgT>
PendingSendType broadcastMsg(
MsgPtrThief<MsgT> msg,
bool deliver_to_sender = true,
TagType tag = no_tag
);

Expand Down Expand Up @@ -890,6 +896,7 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
* \note Takes ownership of the supplied message.
*
* \param[in] msg the message to broadcast
* \param[in] deliver_to_sender whether msg should be delivered to sender
* \param[in] tag the optional tag to put on the message
*
* \return the \c PendingSend for the broadcast
Expand All @@ -900,6 +907,7 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
>
PendingSendType broadcastMsg(
MsgPtrThief<MsgT> msg,
bool deliver_to_sender = true,
TagType tag = no_tag
);

Expand Down Expand Up @@ -1023,6 +1031,7 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
*
* \param[in] han the handler to invoke
* \param[in] msg the message to broadcast
* \param[in] deliver_to_sender whether msg should be delivered to sender
* \param[in] tag the optional tag to put on the message
*
* \return the \c PendingSend for the send
Expand All @@ -1031,6 +1040,7 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
PendingSendType broadcastMsg(
HandlerType han,
MsgPtrThief<MsgT> msg,
bool deliver_to_sender = true,
TagType tag = no_tag
);

Expand Down
24 changes: 20 additions & 4 deletions src/vt/messaging/active.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ ActiveMessenger::PendingSendType ActiveMessenger::sendMsgSerializableImpl(

MsgT* msgp = msg.get();
if (dest == broadcast_dest) {
return SerializedMessenger::broadcastSerialMsg<MsgT>(msgp,han);
return SerializedMessenger::broadcastSerialMsg<MsgT>(
msgp, han, envelopeGetDeliverBcast(msgp->env)
);
} else {
return SerializedMessenger::sendSerialMsg<MsgT>(dest,msgp,han);
}
Expand All @@ -172,20 +174,20 @@ ActiveMessenger::PendingSendType ActiveMessenger::sendMsgCopyableImpl(
MsgT* rawMsg = msg.get();

bool is_term = envelopeIsTerm(rawMsg->env);
const bool is_bcast = dest == broadcast_dest;

if (!is_term || vt_check_enabled(print_term_msgs)) {
vt_debug_print(
active, node,
dest == broadcast_dest
is_bcast
? "broadcastMsg of ptr={}, type={}\n"
: "sendMsg of ptr={}, type={}\n",
print_ptr(rawMsg), typeid(MsgT).name()
);
}

if (dest == broadcast_dest) {
if (is_bcast) {
dest = theContext()->getNode();
setBroadcastType(rawMsg->env);
}
if (msg_size == msgsize_not_specified) {
msg_size = sizeof(MsgT);
Expand Down Expand Up @@ -240,10 +242,14 @@ template <typename MsgT, ActiveTypedFnType<MsgT>* f>
ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgSz(
MsgPtrThief<MsgT> msg,
ByteType msg_size,
bool deliver_to_sender,
TagType tag
) {
auto const han = auto_registry::makeAutoHandler<MsgT,f>();
MsgSharedPtr<MsgT> msgptr = msg.msg_;

setBroadcastType(msgptr->env, deliver_to_sender);

return sendMsgImpl<MsgT>(
broadcast_dest, han, msgptr, msg_size, tag
);
Expand All @@ -252,10 +258,14 @@ ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgSz(
template <typename MsgT, ActiveTypedFnType<MsgT>* f>
ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg(
MsgPtrThief<MsgT> msg,
bool deliver_to_sender,
TagType tag
) {
auto const han = auto_registry::makeAutoHandler<MsgT,f>();
MsgSharedPtr<MsgT> msgptr = msg.msg_;

setBroadcastType(msgptr->env, deliver_to_sender);

return sendMsgImpl<MsgT>(
broadcast_dest, han, msgptr, msgsize_not_specified, tag
);
Expand Down Expand Up @@ -310,10 +320,12 @@ ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgAuto(
template <ActiveFnType* f, typename MsgT>
ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg(
MsgPtrThief<MsgT> msg,
bool deliver_to_sender,
TagType tag
) {
auto const han = auto_registry::makeAutoHandler<MsgT,f>();
MsgSharedPtr<MsgT> msgptr = msg.msg_;
setBroadcastType(msgptr->env, deliver_to_sender);
return sendMsgImpl<MsgT>(
broadcast_dest, han, msgptr, msgsize_not_specified, tag
);
Expand All @@ -333,10 +345,12 @@ ActiveMessenger::PendingSendType ActiveMessenger::sendMsg(
template <typename FunctorT, typename MsgT>
ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg(
MsgPtrThief<MsgT> msg,
bool deliver_to_sender,
TagType tag
) {
auto const han = auto_registry::makeAutoHandlerFunctor<FunctorT,true,MsgT*>();
MsgSharedPtr<MsgT> msgptr = msg.msg_;
setBroadcastType(msgptr->env, deliver_to_sender);
return sendMsgImpl<MsgT>(
broadcast_dest, han, msgptr, msgsize_not_specified, tag
);
Expand Down Expand Up @@ -410,9 +424,11 @@ template <typename MsgT>
ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsg(
HandlerType han,
MsgPtrThief<MsgT> msg,
bool deliver_to_sender,
TagType tag
) {
MsgSharedPtr<MsgT> msgptr = msg.msg_;
setBroadcastType(msgptr->env, deliver_to_sender);
return sendMsgImpl<MsgT>(
broadcast_dest, han, msgptr, msgsize_not_specified, tag
);
Expand Down
4 changes: 4 additions & 0 deletions src/vt/messaging/envelope/envelope_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ struct ActiveEnvelope {
/// True iff the message is considered locked.
/// If locked, changes to the envelope will result in failure.
bool is_locked : 1;

// Used only for broadcast to default group
// Determines whether message should also be sent to the sender
bool deliver_bcast_to_sender : 1;
};

}} /* end namespace vt::messaging */
Expand Down
10 changes: 10 additions & 0 deletions src/vt/messaging/envelope/envelope_get.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ inline HandlerType envelopeGetHandler(Env const& env);
template <typename Env>
inline NodeType envelopeGetDest(Env const& env);

/**
* \brief Check whether bcast should be delivered to sender
*
* \param[in] env the envelope
*
* \return whether to deliver msg to sender
*/
template <typename Env>
inline bool envelopeGetDeliverBcast(Env const& env);

/**
* \brief Get the group on an envelope
*
Expand Down
5 changes: 5 additions & 0 deletions src/vt/messaging/envelope/envelope_get.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ inline NodeType envelopeGetDest(Env const& env) {
return reinterpret_cast<Envelope const*>(&env)->dest;
}

template <typename Env>
inline bool envelopeGetDeliverBcast(Env const& env) {
return reinterpret_cast<Envelope const*>(&env)->deliver_bcast_to_sender;
}

template <typename Env>
inline GroupType envelopeGetGroup(Env& env) {
return reinterpret_cast<Envelope*>(&env)->group;
Expand Down
3 changes: 2 additions & 1 deletion src/vt/messaging/envelope/envelope_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ inline void setTermType(Env& env);
* \brief Set broadcast bit (changes how \c dest is interpreted) \c EnvBroadcast
*
* \param[in,out] env the envelope
* \param[in] deliver_to_sender whether the message should be delivered to sender
*/
template <typename Env>
inline void setBroadcastType(Env& env);
inline void setBroadcastType(Env& env, bool deliver_to_sender = true);

/**
* \brief Set epoch bit \c EnvEpoch
Expand Down
3 changes: 2 additions & 1 deletion src/vt/messaging/envelope/envelope_set.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ inline void setTermType(Env& env) {
}

template <typename Env>
inline void setBroadcastType(Env& env) {
inline void setBroadcastType(Env& env, bool deliver_to_sender) {
vtAssert(not envelopeIsLocked(env), "Envelope locked.");
reinterpret_cast<Envelope*>(&env)->type |= 1 << eEnvType::EnvBroadcast;
reinterpret_cast<Envelope*>(&env)->deliver_bcast_to_sender = deliver_to_sender;
}

template <typename Env>
Expand Down
9 changes: 9 additions & 0 deletions src/vt/messaging/envelope/envelope_setup.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ inline void envelopeInitEmpty(Envelope& env);
template <typename Env>
inline void envelopeInitCopy(Env& env, Env const& src_env);

/**
* \brief Copy broadcast related data
*
* \param[in,out] env the target envelope to copy to
* \param[in] env the original envelope to use as a copy
*/
template <typename Env>
inline void envelopeCopyBcastData(Env& env, Env const& src_env);

/**
* \brief Initialize/validate an envelope that has been received.
*
Expand Down
Loading

0 comments on commit b5edc11

Please sign in to comment.