From b3edb92c13e343a3ed09ee9da3bb7dc47af41f26 Mon Sep 17 00:00:00 2001 From: Jonathan Lifflander Date: Wed, 15 Mar 2023 16:23:04 -0700 Subject: [PATCH] #2102: collection: finish the final fix for this bug --- src/vt/registry/auto/auto_registry_common.h | 37 ++++++++++++- src/vt/vrt/collection/manager.h | 61 +-------------------- src/vt/vrt/collection/manager.impl.h | 42 ++------------ src/vt/vrt/collection/types/base.h | 1 + tests/unit/collection/test_promote.cc | 5 +- 5 files changed, 48 insertions(+), 98 deletions(-) diff --git a/src/vt/registry/auto/auto_registry_common.h b/src/vt/registry/auto/auto_registry_common.h index f309021a77..75076fec5d 100644 --- a/src/vt/registry/auto/auto_registry_common.h +++ b/src/vt/registry/auto/auto_registry_common.h @@ -58,6 +58,13 @@ #include #include +namespace vt::vrt::collection { + +template +struct ColMsgWrap; + +} /* end namespace vt::vrt::collection */ + namespace vt { namespace auto_registry { struct SentinelObject {}; @@ -74,6 +81,18 @@ struct HandlersDispatcher final : BaseHandlersDispatcher { explicit HandlersDispatcher(HandlerT in_fn_ptr) : fp(in_fn_ptr) { } + template + using ColMsgTrait = typename U::IsCollectionMessage; + template + using IsColMsgTrait = + detection::is_detected_convertible; + + template + using ColTrait = typename U::IsCollectionType; + template + using IsColTrait = + detection::is_detected_convertible; + public: void dispatch(messaging::BaseMsg* base_msg, void* object) const override { using T = HandlerT; @@ -86,9 +105,23 @@ struct HandlersDispatcher final : BaseHandlersDispatcher { } else if constexpr (std::is_same_v*>) { fp(msg); } else if constexpr (std::is_same_v) { - fp(elm, msg); + if constexpr (IsColMsgTrait::value or not IsColTrait::value) { + fp(elm, msg); + } else { + auto wrap_msg = static_cast< + vrt::collection::ColMsgWrap* + >(base_msg); + fp(elm, &wrap_msg->getMsg()); + } } else if constexpr (std::is_same_v) { - (elm->*fp)(msg); + if constexpr (IsColMsgTrait::value or not IsColTrait::value) { + (elm->*fp)(msg); + } else { + auto wrap_msg = static_cast< + vrt::collection::ColMsgWrap* + >(base_msg); + (elm->*fp)(&wrap_msg->getMsg()); + } } else if constexpr (std::is_same_v) { std::apply(fp, msg->getTuple()); } else { diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index dc3984ea96..3dd95b156d 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -135,15 +135,6 @@ struct CollectionManager std::is_default_constructible::value, CollectionProxyWrapType >; - template - using IsWrapType = std::enable_if_t< - std::is_same>::value,U - >; - template - using IsNotWrapType = std::enable_if_t< - !std::is_same>::value,U - >; - /** * \internal \brief System call to construct a collection manager */ @@ -625,36 +616,6 @@ struct CollectionManager VirtualElmProxyType const& proxy, MsgT *msg ); - /** - * \internal \brief Deliver a message to a collection element with a promoted - * collection message that wrapped the user's non-collection message. - * - * \param[in] msg the message - * \param[in] col pointer to collection element - * \param[in] han the handler to invoke - * \param[in] from node that sent the message - */ - template - static IsWrapType collectionMsgDeliver( - MsgT* msg, CollectionBase* col, HandlerType han, - NodeType from - ); - - /** - * \internal \brief Deliver a message to a collection element with a normal - * collection message - * - * \param[in] msg the message - * \param[in] col pointer to collection element - * \param[in] han the handler to invoke - * \param[in] from node that sent the message - */ - template - static IsNotWrapType collectionMsgDeliver( - MsgT* msg, CollectionBase* col, HandlerType han, - NodeType from - ); - /** * \internal \brief Base collection message handler * @@ -1152,7 +1113,7 @@ struct CollectionManager public: /** - * \internal \brief Deliver a promoted/wrapped message to a collection element + * \internal \brief Deliver a message to a collection element * * \param[in] msg the message * \param[in] col the collection element pointer @@ -1160,24 +1121,8 @@ struct CollectionManager * \param[in] from the node that sent it * \param[in] event the associated trace event */ - template - static IsWrapType collectionAutoMsgDeliver( - MsgT* msg, Indexable* col, HandlerType han, - NodeType from, trace::TraceEventIDType event, bool immediate - ); - - /** - * \internal \brief Deliver a regular collection message to a collection - * element - * - * \param[in] msg the message - * \param[in] col the collection element pointer - * \param[in] han the handler to invoke - * \param[in] from the node that sent it - * \param[in] event the associated trace event - */ - template - static IsNotWrapType collectionAutoMsgDeliver( + template + static void collectionAutoMsgDeliver( MsgT* msg, Indexable* col, HandlerType han, NodeType from, trace::TraceEventIDType event, bool immediate ); diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index b3673d82fd..3a1a965fc6 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -203,39 +203,8 @@ GroupType CollectionManager::createGroupCollection( return group_id; } -template -/*static*/ CollectionManager::IsWrapType -CollectionManager::collectionAutoMsgDeliver( - MsgT* msg, Indexable* base, HandlerType han, NodeType from, - trace::TraceEventIDType event, bool immediate -) { - // Reference because it's a inner message that should *never* be deallocated - messageRef(&msg->getMsg()); - MsgSharedPtr user_msg{&msg->getMsg()}; - - // Expand out the index for tracing purposes; Projections takes up to - // 4-dimensions -#if vt_check_enabled(trace_enabled) - auto idx = base->getIndex(); - uint64_t const idx1 = idx.ndims() > 0 ? idx[0] : 0; - uint64_t const idx2 = idx.ndims() > 1 ? idx[1] : 0; - uint64_t const idx3 = idx.ndims() > 2 ? idx[2] : 0; - uint64_t const idx4 = idx.ndims() > 3 ? idx[3] : 0; -#endif - - runnable::makeRunnable(user_msg, true, han, from) - .withTDEpoch(theMsg()->getEpochContextMsg(msg)) - .withCollection(base) -#if vt_check_enabled(trace_enabled) - .withTraceIndex(event, idx1, idx2, idx3, idx4) -#endif - .withLBData(base, msg) - .runOrEnqueue(immediate); -} - -template -/*static*/ CollectionManager::IsNotWrapType -CollectionManager::collectionAutoMsgDeliver( +template +/*static*/ void CollectionManager::collectionAutoMsgDeliver( MsgT* msg, Indexable* base, HandlerType han, NodeType from, trace::TraceEventIDType event, bool immediate ) { @@ -250,6 +219,7 @@ CollectionManager::collectionAutoMsgDeliver( #endif auto m = promoteMsg(msg); + runnable::makeRunnable(m, true, han, from) .withTDEpoch(theMsg()->getEpochContextMsg(msg)) .withCollection(base) @@ -294,7 +264,7 @@ template #if vt_check_enabled(trace_enabled) trace_event = col_msg->getFromTraceEvent(); #endif - collectionAutoMsgDeliver( + collectionAutoMsgDeliver( msg, base, handler, from, trace_event, false ); }); @@ -357,7 +327,7 @@ template #if vt_check_enabled(trace_enabled) trace_event = col_msg->getFromTraceEvent(); #endif - collectionAutoMsgDeliver( + collectionAutoMsgDeliver( msg, col_ptr, sub_handler, from, trace_event, false ); theMsg()->popEpoch(cur_epoch); @@ -593,7 +563,7 @@ void CollectionManager::invokeMsgImpl( trace_event = theMsg()->makeTraceCreationSend(han, msg_size, is_bcast); #endif - collectionAutoMsgDeliver( + collectionAutoMsgDeliver( msg.get(), col_ptr, han, from, trace_event, true ); diff --git a/src/vt/vrt/collection/types/base.h b/src/vt/vrt/collection/types/base.h index bd73d99515..263925cada 100644 --- a/src/vt/vrt/collection/types/base.h +++ b/src/vt/vrt/collection/types/base.h @@ -58,6 +58,7 @@ namespace vt { namespace vrt { namespace collection { template struct CollectionBase : Indexable { + using IsCollectionType = std::true_type; using ProxyType = VirtualElmProxyType; using CollectionProxyType = CollectionProxy; diff --git a/tests/unit/collection/test_promote.cc b/tests/unit/collection/test_promote.cc index 3680110692..5c29f56bed 100644 --- a/tests/unit/collection/test_promote.cc +++ b/tests/unit/collection/test_promote.cc @@ -83,8 +83,9 @@ TEST_F(TestCollectionPromoteMsg, test_collection_promote_1) { auto proxy = theCollection()->constructCollective( num_elems, "test_collection_promote_1" ); - proxy.broadcast("hello there"); - + if (this_node == 0) { + proxy.broadcast("hello there"); + } } }}}} // end namespace vt::tests::unit::invoke