diff --git a/src/vt/messaging/active.cc b/src/vt/messaging/active.cc index 94542ecabe..80431fe4ea 100644 --- a/src/vt/messaging/active.cc +++ b/src/vt/messaging/active.cc @@ -498,10 +498,16 @@ EventType ActiveMessenger::doMessageSend( } else { recordLBDataCommForSend(dest, base, base.size()); - runnable::makeRunnable(base, true, envelopeGetHandler(msg->env), dest) - .withTDEpochFromMsg(is_term) - .withLBData(&bare_handler_lb_data_, bare_handler_dummy_elm_id_for_lb_data_) - .enqueue(); + auto han = envelopeGetHandler(msg->env); + bool const is_obj = HandlerManagerType::isHandlerObjGroup(han); + if (is_obj) { + objgroup::dispatchObjGroup(base, han, dest, nullptr); + } else { + runnable::makeRunnable(base, true, envelopeGetHandler(msg->env), dest) + .withTDEpochFromMsg(is_term) + .withLBData(&bare_handler_lb_data_, bare_handler_dummy_elm_id_for_lb_data_) + .enqueue(); + } } return no_event; } @@ -957,8 +963,7 @@ void ActiveMessenger::prepareActiveMsgToRun( bool const is_obj = HandlerManagerType::isHandlerObjGroup(handler); if (is_obj) { - vtAbortIf(cont != nullptr, "Must be nullptr"); - objgroup::dispatchObjGroup(base, handler); + objgroup::dispatchObjGroup(base, handler, from_node, cont); } else { runnable::makeRunnable(base, not is_term, handler, from_node) .withContinuation(cont) diff --git a/src/vt/objgroup/dispatch/dispatch.h b/src/vt/objgroup/dispatch/dispatch.h index 417640fe50..c2fae7b93b 100644 --- a/src/vt/objgroup/dispatch/dispatch.h +++ b/src/vt/objgroup/dispatch/dispatch.h @@ -63,7 +63,9 @@ struct Dispatch final : DispatchBase { virtual ~Dispatch() = default; - void run(HandlerType han, BaseMessage* msg) override; + void run( + HandlerType han, BaseMessage* msg, NodeType from_node, ActionType cont + ) override; void* objPtr() const override { return obj_; } private: diff --git a/src/vt/objgroup/dispatch/dispatch.impl.h b/src/vt/objgroup/dispatch/dispatch.impl.h index 4623cad82f..e8b00c006e 100644 --- a/src/vt/objgroup/dispatch/dispatch.impl.h +++ b/src/vt/objgroup/dispatch/dispatch.impl.h @@ -51,13 +51,15 @@ namespace vt { namespace objgroup { namespace dispatch { template -void Dispatch::run(HandlerType han, BaseMessage* msg) { - //using ActiveFnType = void(ObjT::*)(vt::BaseMessage*); +void Dispatch::run( + HandlerType han, BaseMessage* msg, NodeType from_node, ActionType cont +) { vtAssert(obj_ != nullptr, "Must have a valid object"); auto tmsg = static_cast(msg); auto m = promoteMsg(tmsg); - runnable::makeRunnable(m, true, han, theContext()->getNode()) + runnable::makeRunnable(m, true, han, from_node) + .withContinuation(cont) .withObjGroup(obj_) .withTDEpochFromMsg() .enqueue(); diff --git a/src/vt/objgroup/dispatch/dispatch_base.h b/src/vt/objgroup/dispatch/dispatch_base.h index 42bc9b63ab..142a55d0e9 100644 --- a/src/vt/objgroup/dispatch/dispatch_base.h +++ b/src/vt/objgroup/dispatch/dispatch_base.h @@ -67,7 +67,9 @@ struct DispatchBase { * Dispatch to the handler; the base is closed around the proper object * pointer that is type-erased here */ - virtual void run(HandlerType han, BaseMessage* msg) = 0; + virtual void run( + HandlerType han, BaseMessage* msg, NodeType from_node, ActionType cont + ) = 0; virtual void* objPtr() const = 0; ObjGroupProxyType proxy() const { return proxy_; } diff --git a/src/vt/objgroup/manager.cc b/src/vt/objgroup/manager.cc index a3932db0bb..35109cef1f 100644 --- a/src/vt/objgroup/manager.cc +++ b/src/vt/objgroup/manager.cc @@ -83,7 +83,10 @@ ObjGroupProxyType ObjGroupManager::getProxy(ObjGroupProxyType proxy) { return proxy; } -void ObjGroupManager::dispatch(MsgSharedPtr msg, HandlerType han) { +void ObjGroupManager::dispatch( + MsgSharedPtr msg, HandlerType han, NodeType from_node, + ActionType cont +) { // Extract the control-bit sequence from the handler auto const ctrl = HandlerManager::getHandlerControl(han); vt_debug_print( @@ -103,9 +106,9 @@ void ObjGroupManager::dispatch(MsgSharedPtr msg, HandlerType han) if (epoch != no_epoch and epoch != term::any_epoch_sentinel) { theTerm()->produce(epoch); } - pending_[proxy].push_back(msg); + pending_[proxy].emplace_back(msg, from_node, cont, han); } else { - dispatch_iter->second->run(han,msg.get()); + dispatch_iter->second->run(han, msg.get(), from_node, cont); } } @@ -163,12 +166,15 @@ elm::ElementIDStruct ObjGroupManager::getNextElm(ObjGroupProxyType proxy) { } } -void dispatchObjGroup(MsgSharedPtr msg, HandlerType han) { +void dispatchObjGroup( + MsgSharedPtr msg, HandlerType han, NodeType from_node, + ActionType cont +) { vt_debug_print( verbose, objgroup, "dispatchObjGroup: han={:x}\n", han ); - return theObjGroup()->dispatch(msg,han); + return theObjGroup()->dispatch(msg, han, from_node, cont); } }} /* end namespace vt::objgroup */ diff --git a/src/vt/objgroup/manager.fwd.h b/src/vt/objgroup/manager.fwd.h index 9b67f7e5d2..2bd37cfdee 100644 --- a/src/vt/objgroup/manager.fwd.h +++ b/src/vt/objgroup/manager.fwd.h @@ -60,7 +60,10 @@ namespace detail { holder::HolderBase* getHolderBase(HandlerType handler); } /* end namespace detail */ -void dispatchObjGroup(MsgSharedPtr msg, HandlerType han); +void dispatchObjGroup( + MsgSharedPtr msg, HandlerType han, NodeType from_node, + ActionType cont +); template messaging::PendingSend send(MsgSharedPtr msg, HandlerType han, NodeType node); diff --git a/src/vt/objgroup/manager.h b/src/vt/objgroup/manager.h index a43c107501..65c7aa31c7 100644 --- a/src/vt/objgroup/manager.h +++ b/src/vt/objgroup/manager.h @@ -91,9 +91,29 @@ struct ObjGroupManager : runtime::component::Component { using HolderBasePtrType = std::unique_ptr; using DispatchBaseType = dispatch::DispatchBase; using DispatchBasePtrType = std::unique_ptr; - using MsgContainerType = std::vector>; using PendingSendType = messaging::PendingSend; +private: + struct PendingRecv { + PendingRecv( + MsgSharedPtr in_msg, NodeType in_from_node, + ActionType in_cont, HandlerType in_han + ) : msg_(in_msg), + from_node_(in_from_node), + cont_(in_cont), + han_(in_han) + { } + + MsgSharedPtr msg_; + NodeType from_node_ = uninitialized_destination; + ActionType cont_ = nullptr; + HandlerType han_ = uninitialized_handler; + }; + +public: + using MsgContainerType = std::vector; + + /** * \internal \brief Construct the ObjGroupManager */ @@ -338,8 +358,13 @@ struct ObjGroupManager : runtime::component::Component { * * \param[in] msg the message * \param[in] han the handler to invoke + * \param[in] from_node the node it was from + * \param[in] cont optional continuation to execute after */ - void dispatch(MsgSharedPtr msg, HandlerType han); + void dispatch( + MsgSharedPtr msg, HandlerType han, NodeType from_node, + ActionType cont + ); /** * \internal \brief Send a message to an objgroup diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h index dffb5f7917..75637bfbba 100644 --- a/src/vt/objgroup/manager.impl.h +++ b/src/vt/objgroup/manager.impl.h @@ -167,15 +167,13 @@ void ObjGroupManager::regObjProxy(ObjT* obj, ObjGroupProxyType proxy) { ); auto pending_iter = pending_.find(proxy); if (pending_iter != pending_.end()) { - for (auto&& msg : pending_iter->second) { - theSched()->enqueue([msg]{ - auto const handler = envelopeGetHandler(msg->env); - auto const epoch = envelopeGetEpoch(msg->env); - theObjGroup()->dispatch(msg,handler); - if (epoch != no_epoch) { - theTerm()->consume(epoch); - } - }); + for (auto&& pending : pending_iter->second) { + auto const& msg = pending.msg_; + auto const epoch = envelopeGetEpoch(msg->env); + dispatch(msg, pending.han_, pending.from_node_, pending.cont_); + if (epoch != no_epoch) { + theTerm()->consume(epoch); + } } pending_.erase(pending_iter); } diff --git a/src/vt/objgroup/manager.static.h b/src/vt/objgroup/manager.static.h index faf6aaa5b3..8fb3df3696 100644 --- a/src/vt/objgroup/manager.static.h +++ b/src/vt/objgroup/manager.static.h @@ -66,9 +66,11 @@ messaging::PendingSend send(MsgSharedPtr msg, HandlerType han, NodeType de return messaging::PendingSend{cur_epoch, [msg, han, cur_epoch, this_node](){ auto holder = detail::getHolderBase(han); auto const& elm_id = holder->getElmID(); + auto elm = holder->getPtr(); auto lb_data = &holder->getLBData(); runnable::makeRunnable(msg, true, han, this_node) + .withObjGroup(elm) .withTDEpoch(cur_epoch) .withLBData(lb_data, elm_id) .enqueue(); @@ -89,8 +91,14 @@ void invoke(messaging::MsgPtrThief msg, HandlerType han, NodeType dest_nod ); // this is a local invocation.. no thread required + auto holder = detail::getHolderBase(han); + auto const& elm_id = holder->getElmID(); + auto elm = holder->getPtr(); + auto lb_data = &holder->getLBData(); runnable::makeRunnable(msg.msg_, false, han, this_node) + .withObjGroup(elm) .withTDEpochFromMsg() + .withLBData(lb_data, elm_id) .run(); } diff --git a/src/vt/runnable/make_runnable.h b/src/vt/runnable/make_runnable.h index 7113ac9080..795f4bbfc2 100644 --- a/src/vt/runnable/make_runnable.h +++ b/src/vt/runnable/make_runnable.h @@ -96,7 +96,9 @@ struct RunnableMaker { * \param[in] cont the continuation */ RunnableMaker&& withContinuation(ActionType cont) { - impl_->addContextCont(cont); + if (cont != nullptr) { + impl_->addContextCont(cont); + } return std::move(*this); } diff --git a/src/vt/runnable/runnable.cc b/src/vt/runnable/runnable.cc index 60fd4f5a44..10ac518088 100644 --- a/src/vt/runnable/runnable.cc +++ b/src/vt/runnable/runnable.cc @@ -56,6 +56,9 @@ namespace vt { namespace runnable { void RunnableNew::setupHandler(HandlerType handler) { using HandlerManagerType = HandlerManager; + bool const is_obj = HandlerManagerType::isHandlerObjGroup(handler); + vtAssert(not is_obj, "Must not be object"); + bool const is_auto = HandlerManagerType::isHandlerAuto(handler); bool const is_functor = HandlerManagerType::isHandlerFunctor(handler); diff --git a/src/vt/serialization/messaging/serialized_messenger.impl.h b/src/vt/serialization/messaging/serialized_messenger.impl.h index fa65da93d4..5b9c296ae0 100644 --- a/src/vt/serialization/messaging/serialized_messenger.impl.h +++ b/src/vt/serialization/messaging/serialized_messenger.impl.h @@ -53,6 +53,7 @@ #include "vt/serialization/messaging/serialized_messenger.h" #include "vt/messaging/envelope/envelope_set.h" // envelopeSetRef #include "vt/runnable/make_runnable.h" +#include "vt/objgroup/manager.fwd.h" #include #include @@ -88,9 +89,16 @@ template auto msg_data = ptr_offset; auto user_msg = deserializeFullMessage(msg_data); - runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node) - .withTDEpochFromMsg() - .enqueue(); + bool const is_obj = HandlerManager::isHandlerObjGroup(handler); + if (is_obj) { + objgroup::dispatchObjGroup( + user_msg.template to(), handler, sys_msg->from_node, nullptr + ); + } else { + runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node) + .withTDEpochFromMsg() + .enqueue(); + } } template @@ -132,10 +140,17 @@ template handler, recv_tag, envelopeGetEpoch(msg->env) ); - runnable::makeRunnable(msg, true, handler, node) - .withTDEpoch(epoch, not is_valid_epoch) - .withContinuation(action) - .enqueue(); + bool const is_obj = HandlerManager::isHandlerObjGroup(handler); + if (is_obj) { + objgroup::dispatchObjGroup( + msg.template to(), handler, node, action + ); + } else { + runnable::makeRunnable(msg, true, handler, node) + .withTDEpoch(epoch, not is_valid_epoch) + .withContinuation(action) + .enqueue(); + } if (is_valid_epoch) { theTerm()->consume(epoch); @@ -174,9 +189,16 @@ template print_ptr(user_msg.get()), envelopeGetEpoch(sys_msg->env) ); - runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node) - .withTDEpochFromMsg() - .enqueue(); + bool const is_obj = HandlerManager::isHandlerObjGroup(handler); + if (is_obj) { + objgroup::dispatchObjGroup( + user_msg.template to(), handler, sys_msg->from_node, nullptr + ); + } else { + runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node) + .withTDEpochFromMsg() + .enqueue(); + } } template @@ -407,9 +429,16 @@ template auto base_msg = user_msg.template to(); return messaging::PendingSend(base_msg, [=](MsgPtr in) { - runnable::makeRunnable(user_msg, true, typed_handler, node) - .withTDEpochFromMsg() - .enqueue(); + bool const is_obj = HandlerManager::isHandlerObjGroup(typed_handler); + if (is_obj) { + objgroup::dispatchObjGroup( + user_msg.template to(), typed_handler, node, nullptr + ); + } else { + runnable::makeRunnable(user_msg, true, typed_handler, node) + .withTDEpochFromMsg() + .enqueue(); + } }); } };