Skip to content

Commit

Permalink
#1983: runnable: implement fixes for objgroup dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Sep 28, 2022
1 parent ec25141 commit 3680f52
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 42 deletions.
17 changes: 11 additions & 6 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,16 @@ EventType ActiveMessenger::doMessageSend(
} else {
recordLBDataCommForSend(dest, base, base.size());

runnable::makeRunnable(base, true, envelopeGetHandler(msg->env), dest)
.withTDEpochFromMsg(is_term)
.withLBData(&bare_handler_lb_data_, bare_handler_dummy_elm_id_for_lb_data_)
.enqueue();
auto han = envelopeGetHandler(msg->env);
bool const is_obj = HandlerManagerType::isHandlerObjGroup(han);
if (is_obj) {
objgroup::dispatchObjGroup(base, han, dest, nullptr);
} else {
runnable::makeRunnable(base, true, envelopeGetHandler(msg->env), dest)
.withTDEpochFromMsg(is_term)
.withLBData(&bare_handler_lb_data_, bare_handler_dummy_elm_id_for_lb_data_)
.enqueue();
}
}
return no_event;
}
Expand Down Expand Up @@ -957,8 +963,7 @@ void ActiveMessenger::prepareActiveMsgToRun(

bool const is_obj = HandlerManagerType::isHandlerObjGroup(handler);
if (is_obj) {
vtAbortIf(cont != nullptr, "Must be nullptr");
objgroup::dispatchObjGroup(base, handler);
objgroup::dispatchObjGroup(base, handler, from_node, cont);
} else {
runnable::makeRunnable(base, not is_term, handler, from_node)
.withContinuation(cont)
Expand Down
4 changes: 3 additions & 1 deletion src/vt/objgroup/dispatch/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ struct Dispatch final : DispatchBase {

virtual ~Dispatch() = default;

void run(HandlerType han, BaseMessage* msg) override;
void run(
HandlerType han, BaseMessage* msg, NodeType from_node, ActionType cont
) override;
void* objPtr() const override { return obj_; }

private:
Expand Down
8 changes: 5 additions & 3 deletions src/vt/objgroup/dispatch/dispatch.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@
namespace vt { namespace objgroup { namespace dispatch {

template <typename ObjT>
void Dispatch<ObjT>::run(HandlerType han, BaseMessage* msg) {
//using ActiveFnType = void(ObjT::*)(vt::BaseMessage*);
void Dispatch<ObjT>::run(
HandlerType han, BaseMessage* msg, NodeType from_node, ActionType cont
) {
vtAssert(obj_ != nullptr, "Must have a valid object");

auto tmsg = static_cast<vt::Message*>(msg);
auto m = promoteMsg(tmsg);
runnable::makeRunnable(m, true, han, theContext()->getNode())
runnable::makeRunnable(m, true, han, from_node)
.withContinuation(cont)
.withObjGroup(obj_)
.withTDEpochFromMsg()
.enqueue();
Expand Down
4 changes: 3 additions & 1 deletion src/vt/objgroup/dispatch/dispatch_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ struct DispatchBase {
* Dispatch to the handler; the base is closed around the proper object
* pointer that is type-erased here
*/
virtual void run(HandlerType han, BaseMessage* msg) = 0;
virtual void run(
HandlerType han, BaseMessage* msg, NodeType from_node, ActionType cont
) = 0;
virtual void* objPtr() const = 0;

ObjGroupProxyType proxy() const { return proxy_; }
Expand Down
16 changes: 11 additions & 5 deletions src/vt/objgroup/manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ ObjGroupProxyType ObjGroupManager::getProxy(ObjGroupProxyType proxy) {
return proxy;
}

void ObjGroupManager::dispatch(MsgSharedPtr<ShortMessage> msg, HandlerType han) {
void ObjGroupManager::dispatch(
MsgSharedPtr<ShortMessage> msg, HandlerType han, NodeType from_node,
ActionType cont
) {
// Extract the control-bit sequence from the handler
auto const ctrl = HandlerManager::getHandlerControl(han);
vt_debug_print(
Expand All @@ -103,9 +106,9 @@ void ObjGroupManager::dispatch(MsgSharedPtr<ShortMessage> msg, HandlerType han)
if (epoch != no_epoch and epoch != term::any_epoch_sentinel) {
theTerm()->produce(epoch);
}
pending_[proxy].push_back(msg);
pending_[proxy].emplace_back(msg, from_node, cont, han);
} else {
dispatch_iter->second->run(han,msg.get());
dispatch_iter->second->run(han, msg.get(), from_node, cont);
}
}

Expand Down Expand Up @@ -163,12 +166,15 @@ elm::ElementIDStruct ObjGroupManager::getNextElm(ObjGroupProxyType proxy) {
}
}

void dispatchObjGroup(MsgSharedPtr<ShortMessage> msg, HandlerType han) {
void dispatchObjGroup(
MsgSharedPtr<ShortMessage> msg, HandlerType han, NodeType from_node,
ActionType cont
) {
vt_debug_print(
verbose, objgroup,
"dispatchObjGroup: han={:x}\n", han
);
return theObjGroup()->dispatch(msg,han);
return theObjGroup()->dispatch(msg, han, from_node, cont);
}

}} /* end namespace vt::objgroup */
5 changes: 4 additions & 1 deletion src/vt/objgroup/manager.fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ namespace detail {
holder::HolderBase* getHolderBase(HandlerType handler);
} /* end namespace detail */

void dispatchObjGroup(MsgSharedPtr<ShortMessage> msg, HandlerType han);
void dispatchObjGroup(
MsgSharedPtr<ShortMessage> msg, HandlerType han, NodeType from_node,
ActionType cont
);

template <typename MsgT>
messaging::PendingSend send(MsgSharedPtr<MsgT> msg, HandlerType han, NodeType node);
Expand Down
29 changes: 27 additions & 2 deletions src/vt/objgroup/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,29 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
using HolderBasePtrType = std::unique_ptr<HolderBaseType>;
using DispatchBaseType = dispatch::DispatchBase;
using DispatchBasePtrType = std::unique_ptr<DispatchBaseType>;
using MsgContainerType = std::vector<MsgSharedPtr<ShortMessage>>;
using PendingSendType = messaging::PendingSend;

private:
struct PendingRecv {
PendingRecv(
MsgSharedPtr<ShortMessage> in_msg, NodeType in_from_node,
ActionType in_cont, HandlerType in_han
) : msg_(in_msg),
from_node_(in_from_node),
cont_(in_cont),
han_(in_han)
{ }

MsgSharedPtr<ShortMessage> msg_;
NodeType from_node_ = uninitialized_destination;
ActionType cont_ = nullptr;
HandlerType han_ = uninitialized_handler;
};

public:
using MsgContainerType = std::vector<PendingRecv>;


/**
* \internal \brief Construct the ObjGroupManager
*/
Expand Down Expand Up @@ -338,8 +358,13 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
*
* \param[in] msg the message
* \param[in] han the handler to invoke
* \param[in] from_node the node it was from
* \param[in] cont optional continuation to execute after
*/
void dispatch(MsgSharedPtr<ShortMessage> msg, HandlerType han);
void dispatch(
MsgSharedPtr<ShortMessage> msg, HandlerType han, NodeType from_node,
ActionType cont
);

/**
* \internal \brief Send a message to an objgroup
Expand Down
16 changes: 7 additions & 9 deletions src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,13 @@ void ObjGroupManager::regObjProxy(ObjT* obj, ObjGroupProxyType proxy) {
);
auto pending_iter = pending_.find(proxy);
if (pending_iter != pending_.end()) {
for (auto&& msg : pending_iter->second) {
theSched()->enqueue([msg]{
auto const handler = envelopeGetHandler(msg->env);
auto const epoch = envelopeGetEpoch(msg->env);
theObjGroup()->dispatch(msg,handler);
if (epoch != no_epoch) {
theTerm()->consume(epoch);
}
});
for (auto&& pending : pending_iter->second) {
auto const& msg = pending.msg_;
auto const epoch = envelopeGetEpoch(msg->env);
dispatch(msg, pending.han_, pending.from_node_, pending.cont_);
if (epoch != no_epoch) {
theTerm()->consume(epoch);
}
}
pending_.erase(pending_iter);
}
Expand Down
8 changes: 8 additions & 0 deletions src/vt/objgroup/manager.static.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ messaging::PendingSend send(MsgSharedPtr<MsgT> msg, HandlerType han, NodeType de
return messaging::PendingSend{cur_epoch, [msg, han, cur_epoch, this_node](){
auto holder = detail::getHolderBase(han);
auto const& elm_id = holder->getElmID();
auto elm = holder->getPtr();
auto lb_data = &holder->getLBData();

runnable::makeRunnable(msg, true, han, this_node)
.withObjGroup(elm)
.withTDEpoch(cur_epoch)
.withLBData(lb_data, elm_id)
.enqueue();
Expand All @@ -89,8 +91,14 @@ void invoke(messaging::MsgPtrThief<MsgT> msg, HandlerType han, NodeType dest_nod
);

// this is a local invocation.. no thread required
auto holder = detail::getHolderBase(han);
auto const& elm_id = holder->getElmID();
auto elm = holder->getPtr();
auto lb_data = &holder->getLBData();
runnable::makeRunnable(msg.msg_, false, han, this_node)
.withObjGroup(elm)
.withTDEpochFromMsg()
.withLBData(lb_data, elm_id)
.run();
}

Expand Down
4 changes: 3 additions & 1 deletion src/vt/runnable/make_runnable.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ struct RunnableMaker {
* \param[in] cont the continuation
*/
RunnableMaker&& withContinuation(ActionType cont) {
impl_->addContextCont(cont);
if (cont != nullptr) {
impl_->addContextCont(cont);
}
return std::move(*this);
}

Expand Down
3 changes: 3 additions & 0 deletions src/vt/runnable/runnable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ namespace vt { namespace runnable {
void RunnableNew::setupHandler(HandlerType handler) {
using HandlerManagerType = HandlerManager;

bool const is_obj = HandlerManagerType::isHandlerObjGroup(handler);
vtAssert(not is_obj, "Must not be object");

bool const is_auto = HandlerManagerType::isHandlerAuto(handler);
bool const is_functor = HandlerManagerType::isHandlerFunctor(handler);

Expand Down
55 changes: 42 additions & 13 deletions src/vt/serialization/messaging/serialized_messenger.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "vt/serialization/messaging/serialized_messenger.h"
#include "vt/messaging/envelope/envelope_set.h" // envelopeSetRef
#include "vt/runnable/make_runnable.h"
#include "vt/objgroup/manager.fwd.h"

#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -88,9 +89,16 @@ template <typename UserMsgT>
auto msg_data = ptr_offset;
auto user_msg = deserializeFullMessage<UserMsgT>(msg_data);

runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node)
.withTDEpochFromMsg()
.enqueue();
bool const is_obj = HandlerManager::isHandlerObjGroup(handler);
if (is_obj) {
objgroup::dispatchObjGroup(
user_msg.template to<BaseMsgType>(), handler, sys_msg->from_node, nullptr
);
} else {
runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node)
.withTDEpochFromMsg()
.enqueue();
}
}

template <typename UserMsgT>
Expand Down Expand Up @@ -132,10 +140,17 @@ template <typename UserMsgT>
handler, recv_tag, envelopeGetEpoch(msg->env)
);

runnable::makeRunnable(msg, true, handler, node)
.withTDEpoch(epoch, not is_valid_epoch)
.withContinuation(action)
.enqueue();
bool const is_obj = HandlerManager::isHandlerObjGroup(handler);
if (is_obj) {
objgroup::dispatchObjGroup(
msg.template to<BaseMsgType>(), handler, node, action
);
} else {
runnable::makeRunnable(msg, true, handler, node)
.withTDEpoch(epoch, not is_valid_epoch)
.withContinuation(action)
.enqueue();
}

if (is_valid_epoch) {
theTerm()->consume(epoch);
Expand Down Expand Up @@ -174,9 +189,16 @@ template <typename UserMsgT, typename BaseEagerMsgT>
print_ptr(user_msg.get()), envelopeGetEpoch(sys_msg->env)
);

runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node)
.withTDEpochFromMsg()
.enqueue();
bool const is_obj = HandlerManager::isHandlerObjGroup(handler);
if (is_obj) {
objgroup::dispatchObjGroup(
user_msg.template to<BaseMsgType>(), handler, sys_msg->from_node, nullptr
);
} else {
runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node)
.withTDEpochFromMsg()
.enqueue();
}
}

template <typename MsgT, typename BaseT>
Expand Down Expand Up @@ -407,9 +429,16 @@ template <typename MsgT, typename BaseT>

auto base_msg = user_msg.template to<BaseMsgType>();
return messaging::PendingSend(base_msg, [=](MsgPtr<BaseMsgType> in) {
runnable::makeRunnable(user_msg, true, typed_handler, node)
.withTDEpochFromMsg()
.enqueue();
bool const is_obj = HandlerManager::isHandlerObjGroup(typed_handler);
if (is_obj) {
objgroup::dispatchObjGroup(
user_msg.template to<BaseMsgType>(), typed_handler, node, nullptr
);
} else {
runnable::makeRunnable(user_msg, true, typed_handler, node)
.withTDEpochFromMsg()
.enqueue();
}
});
}
};
Expand Down

0 comments on commit 3680f52

Please sign in to comment.