Skip to content

Commit

Permalink
#2102: collection: finish the final fix for this bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander authored and cz4rs committed Mar 27, 2023
1 parent b8bbcec commit b3edb92
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 98 deletions.
37 changes: 35 additions & 2 deletions src/vt/registry/auto/auto_registry_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@
#include <cstdlib>
#include <functional>

namespace vt::vrt::collection {

template <typename ColT, typename UserMsgT, typename BaseMsgT>
struct ColMsgWrap;

} /* end namespace vt::vrt::collection */

namespace vt { namespace auto_registry {

struct SentinelObject {};
Expand All @@ -74,6 +81,18 @@ struct HandlersDispatcher final : BaseHandlersDispatcher {

explicit HandlersDispatcher(HandlerT in_fn_ptr) : fp(in_fn_ptr) { }

template <typename U>
using ColMsgTrait = typename U::IsCollectionMessage;
template <typename U>
using IsColMsgTrait =
detection::is_detected_convertible<std::true_type, ColMsgTrait, U>;

template <typename U>
using ColTrait = typename U::IsCollectionType;
template <typename U>
using IsColTrait =
detection::is_detected_convertible<std::true_type, ColTrait, U>;

public:
void dispatch(messaging::BaseMsg* base_msg, void* object) const override {
using T = HandlerT;
Expand All @@ -86,9 +105,23 @@ struct HandlersDispatcher final : BaseHandlersDispatcher {
} else if constexpr (std::is_same_v<T, ActiveTypedFnType<MsgT>*>) {
fp(msg);
} else if constexpr (std::is_same_v<T, ColTypedFnType*>) {
fp(elm, msg);
if constexpr (IsColMsgTrait<MsgT>::value or not IsColTrait<ObjT>::value) {
fp(elm, msg);
} else {
auto wrap_msg = static_cast<
vrt::collection::ColMsgWrap<ObjT, MsgT, vt::Message>*
>(base_msg);
fp(elm, &wrap_msg->getMsg());
}
} else if constexpr (std::is_same_v<T, ColMemberTypedFnType>) {
(elm->*fp)(msg);
if constexpr (IsColMsgTrait<MsgT>::value or not IsColTrait<ObjT>::value) {
(elm->*fp)(msg);
} else {
auto wrap_msg = static_cast<
vrt::collection::ColMsgWrap<ObjT, MsgT, vt::Message>*
>(base_msg);
(elm->*fp)(&wrap_msg->getMsg());
}
} else if constexpr (std::is_same_v<ObjT, SentinelObject>) {
std::apply(fp, msg->getTuple());
} else {
Expand Down
61 changes: 3 additions & 58 deletions src/vt/vrt/collection/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ struct CollectionManager
std::is_default_constructible<ColT>::value, CollectionProxyWrapType<ColT, IndexT>
>;

template <typename ColT, typename UserMsgT, typename T, typename U=void>
using IsWrapType = std::enable_if_t<
std::is_same<T,ColMsgWrap<ColT,UserMsgT>>::value,U
>;
template <typename ColT, typename UserMsgT, typename T, typename U=void>
using IsNotWrapType = std::enable_if_t<
!std::is_same<T,ColMsgWrap<ColT,UserMsgT>>::value,U
>;

/**
* \internal \brief System call to construct a collection manager
*/
Expand Down Expand Up @@ -625,36 +616,6 @@ struct CollectionManager
VirtualElmProxyType<ColT> const& proxy, MsgT *msg
);

/**
* \internal \brief Deliver a message to a collection element with a promoted
* collection message that wrapped the user's non-collection message.
*
* \param[in] msg the message
* \param[in] col pointer to collection element
* \param[in] han the handler to invoke
* \param[in] from node that sent the message
*/
template <typename ColT, typename IndexT, typename MsgT, typename UserMsgT>
static IsWrapType<ColT, UserMsgT, MsgT> collectionMsgDeliver(
MsgT* msg, CollectionBase<ColT, IndexT>* col, HandlerType han,
NodeType from
);

/**
* \internal \brief Deliver a message to a collection element with a normal
* collection message
*
* \param[in] msg the message
* \param[in] col pointer to collection element
* \param[in] han the handler to invoke
* \param[in] from node that sent the message
*/
template <typename ColT, typename IndexT, typename MsgT, typename UserMsgT>
static IsNotWrapType<ColT, UserMsgT, MsgT> collectionMsgDeliver(
MsgT* msg, CollectionBase<ColT, IndexT>* col, HandlerType han,
NodeType from
);

/**
* \internal \brief Base collection message handler
*
Expand Down Expand Up @@ -1152,32 +1113,16 @@ struct CollectionManager

public:
/**
* \internal \brief Deliver a promoted/wrapped message to a collection element
* \internal \brief Deliver a message to a collection element
*
* \param[in] msg the message
* \param[in] col the collection element pointer
* \param[in] han the handler to invoke
* \param[in] from the node that sent it
* \param[in] event the associated trace event
*/
template <typename ColT, typename IndexT, typename MsgT, typename UserMsgT>
static IsWrapType<ColT, UserMsgT, MsgT> collectionAutoMsgDeliver(
MsgT* msg, Indexable<IndexT>* col, HandlerType han,
NodeType from, trace::TraceEventIDType event, bool immediate
);

/**
* \internal \brief Deliver a regular collection message to a collection
* element
*
* \param[in] msg the message
* \param[in] col the collection element pointer
* \param[in] han the handler to invoke
* \param[in] from the node that sent it
* \param[in] event the associated trace event
*/
template <typename ColT, typename IndexT, typename MsgT, typename UserMsgT>
static IsNotWrapType<ColT, UserMsgT, MsgT> collectionAutoMsgDeliver(
template <typename ColT, typename IndexT, typename MsgT>
static void collectionAutoMsgDeliver(
MsgT* msg, Indexable<IndexT>* col, HandlerType han,
NodeType from, trace::TraceEventIDType event, bool immediate
);
Expand Down
42 changes: 6 additions & 36 deletions src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,39 +203,8 @@ GroupType CollectionManager::createGroupCollection(
return group_id;
}

template <typename ColT, typename IndexT, typename MsgT, typename UserMsgT>
/*static*/ CollectionManager::IsWrapType<ColT, UserMsgT, MsgT>
CollectionManager::collectionAutoMsgDeliver(
MsgT* msg, Indexable<IndexT>* base, HandlerType han, NodeType from,
trace::TraceEventIDType event, bool immediate
) {
// Reference because it's a inner message that should *never* be deallocated
messageRef(&msg->getMsg());
MsgSharedPtr<UserMsgT> user_msg{&msg->getMsg()};

// Expand out the index for tracing purposes; Projections takes up to
// 4-dimensions
#if vt_check_enabled(trace_enabled)
auto idx = base->getIndex();
uint64_t const idx1 = idx.ndims() > 0 ? idx[0] : 0;
uint64_t const idx2 = idx.ndims() > 1 ? idx[1] : 0;
uint64_t const idx3 = idx.ndims() > 2 ? idx[2] : 0;
uint64_t const idx4 = idx.ndims() > 3 ? idx[3] : 0;
#endif

runnable::makeRunnable(user_msg, true, han, from)
.withTDEpoch(theMsg()->getEpochContextMsg(msg))
.withCollection(base)
#if vt_check_enabled(trace_enabled)
.withTraceIndex(event, idx1, idx2, idx3, idx4)
#endif
.withLBData(base, msg)
.runOrEnqueue(immediate);
}

template <typename ColT, typename IndexT, typename MsgT, typename UserMsgT>
/*static*/ CollectionManager::IsNotWrapType<ColT, UserMsgT, MsgT>
CollectionManager::collectionAutoMsgDeliver(
template <typename ColT, typename IndexT, typename MsgT>
/*static*/ void CollectionManager::collectionAutoMsgDeliver(
MsgT* msg, Indexable<IndexT>* base, HandlerType han, NodeType from,
trace::TraceEventIDType event, bool immediate
) {
Expand All @@ -250,6 +219,7 @@ CollectionManager::collectionAutoMsgDeliver(
#endif

auto m = promoteMsg(msg);

runnable::makeRunnable(m, true, han, from)
.withTDEpoch(theMsg()->getEpochContextMsg(msg))
.withCollection(base)
Expand Down Expand Up @@ -294,7 +264,7 @@ template <typename ColT, typename IndexT, typename MsgT>
#if vt_check_enabled(trace_enabled)
trace_event = col_msg->getFromTraceEvent();
#endif
collectionAutoMsgDeliver<ColT,IndexT,MsgT,typename MsgT::UserMsgType>(
collectionAutoMsgDeliver<ColT,IndexT,MsgT>(
msg, base, handler, from, trace_event, false
);
});
Expand Down Expand Up @@ -357,7 +327,7 @@ template <typename ColT, typename IndexT, typename MsgT>
#if vt_check_enabled(trace_enabled)
trace_event = col_msg->getFromTraceEvent();
#endif
collectionAutoMsgDeliver<ColT,IndexT,MsgT,typename MsgT::UserMsgType>(
collectionAutoMsgDeliver<ColT,IndexT,MsgT>(
msg, col_ptr, sub_handler, from, trace_event, false
);
theMsg()->popEpoch(cur_epoch);
Expand Down Expand Up @@ -593,7 +563,7 @@ void CollectionManager::invokeMsgImpl(
trace_event = theMsg()->makeTraceCreationSend(han, msg_size, is_bcast);
#endif

collectionAutoMsgDeliver<ColT, IndexT, MsgT, typename MsgT::UserMsgType>(
collectionAutoMsgDeliver<ColT, IndexT, MsgT>(
msg.get(), col_ptr, han, from, trace_event, true
);

Expand Down
1 change: 1 addition & 0 deletions src/vt/vrt/collection/types/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ namespace vt { namespace vrt { namespace collection {

template <typename ColT, typename IndexT>
struct CollectionBase : Indexable<IndexT> {
using IsCollectionType = std::true_type;
using ProxyType = VirtualElmProxyType<ColT, IndexT>;
using CollectionProxyType = CollectionProxy<ColT, IndexT>;

Expand Down
5 changes: 3 additions & 2 deletions tests/unit/collection/test_promote.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ TEST_F(TestCollectionPromoteMsg, test_collection_promote_1) {
auto proxy = theCollection()->constructCollective<Hello>(
num_elems, "test_collection_promote_1"
);
proxy.broadcast<TestMsg,&Hello::doWork>("hello there");

if (this_node == 0) {
proxy.broadcast<TestMsg,&Hello::doWork>("hello there");
}
}

}}}} // end namespace vt::tests::unit::invoke

0 comments on commit b3edb92

Please sign in to comment.