Skip to content

Commit

Permalink
#702: reduce: make reduce a delayed call by returning PendingSend for…
Browse files Browse the repository at this point in the history
… collective reduce, collections, and objgroups

This makes returning a sequentialID impossible. In order for collections to order properly, add a reduceImmediate public overload set that takes the same parameters as reduce but returns the sequential id instead of a PendingSend. This is used internally by the new reduce and collection reduceMsgExprImpl, but can be called by the user.
  • Loading branch information
nmm0 committed Feb 28, 2020
1 parent 58a164b commit 7e8dc44
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 45 deletions.
48 changes: 42 additions & 6 deletions src/vt/collective/reduce/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include "vt/messaging/message.h"
#include "vt/collective/tree/tree.h"
#include "vt/utils/hash/hash_tuple.h"
#include "vt/messaging/pending_send.h"

#include <tuple>
#include <unordered_map>
Expand All @@ -72,14 +73,15 @@ struct Reduce : virtual collective::tree::Tree {
template <typename T>
using ReduceStateType = ReduceState<T>;
using ReduceNumType = typename ReduceState<void>::ReduceNumType;
using PendingSendType = messaging::PendingSend;

Reduce();
Reduce(GroupType const& group, collective::tree::Tree* in_tree);

template <typename MessageT, ActiveTypedFnType<MessageT>* f>
SequentialIDType reduce(
PendingSendType reduce(
NodeType root, MessageT* const msg, TagType tag = no_tag,
SequentialIDType seq = no_seq_id, ReduceNumType num_contrib = 1,
SequentialIDType in_seq = no_seq_id, ReduceNumType num_contrib = 1,
VirtualProxyType proxy = no_vrt_proxy,
ObjGroupProxyType obj_group = no_obj_group
);
Expand All @@ -91,9 +93,9 @@ struct Reduce : virtual collective::tree::Tree {
MsgT, OpT, collective::reduce::operators::ReduceCallback<MsgT>
>
>
SequentialIDType reduce(
PendingSendType reduce(
NodeType const& root, MsgT* msg, Callback<MsgT> cb,
TagType const& tag = no_tag, SequentialIDType const& seq = no_seq_id,
TagType const& tag = no_tag, SequentialIDType in_seq = no_seq_id,
ReduceNumType const& num_contrib = 1,
VirtualProxyType const& proxy = no_vrt_proxy
);
Expand All @@ -104,9 +106,43 @@ struct Reduce : virtual collective::tree::Tree {
typename MsgT,
ActiveTypedFnType<MsgT> *f = MsgT::template msgHandler<MsgT, OpT, FunctorT>
>
SequentialIDType reduce(
PendingSendType reduce(
NodeType const& root, MsgT* msg, TagType const& tag = no_tag,
SequentialIDType const& seq = no_seq_id, ReduceNumType const& num_contrib = 1,
SequentialIDType in_seq = no_seq_id, ReduceNumType const& num_contrib = 1,
VirtualProxyType const& proxy = no_vrt_proxy
);

template <typename MessageT, ActiveTypedFnType<MessageT>* f>
SequentialIDType reduceImmediate(
NodeType root, MessageT* const msg, TagType tag = no_tag,
SequentialIDType in_seq = no_seq_id, ReduceNumType num_contrib = 1,
VirtualProxyType proxy = no_vrt_proxy,
ObjGroupProxyType obj_group = no_obj_group
);

template <
typename OpT,
typename MsgT,
ActiveTypedFnType<MsgT> *f = MsgT::template msgHandler<
MsgT, OpT, collective::reduce::operators::ReduceCallback<MsgT>
>
>
SequentialIDType reduceImmediate(
NodeType const& root, MsgT* msg, Callback<MsgT> cb,
TagType const& tag = no_tag, SequentialIDType in_seq = no_seq_id,
ReduceNumType const& num_contrib = 1,
VirtualProxyType const& proxy = no_vrt_proxy
);

template <
typename OpT,
typename FunctorT,
typename MsgT,
ActiveTypedFnType<MsgT> *f = MsgT::template msgHandler<MsgT, OpT, FunctorT>
>
SequentialIDType reduceImmediate(
NodeType const& root, MsgT* msg, TagType const& tag = no_tag,
SequentialIDType in_seq = no_seq_id, ReduceNumType const& num_contrib = 1,
VirtualProxyType const& proxy = no_vrt_proxy
);

Expand Down
54 changes: 46 additions & 8 deletions src/vt/collective/reduce/reduce.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,64 @@ template <typename MessageT>
}

template <typename OpT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType Reduce::reduce(
Reduce::PendingSendType Reduce::reduce(
NodeType const& root, MsgT* msg, Callback<MsgT> cb, TagType const& tag,
SequentialIDType const& seq, ReduceNumType const& num_contrib,
SequentialIDType in_seq, ReduceNumType const& num_contrib,
VirtualProxyType const& proxy
) {
msg->setCallback(cb);
return reduce<MsgT,f>(root,msg,tag,seq,num_contrib,proxy);
return reduce<MsgT,f>(root,msg,tag,in_seq,num_contrib,proxy);
}

template <
typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType<MsgT> *f
>
SequentialIDType Reduce::reduce(
Reduce::PendingSendType Reduce::reduce(
NodeType const& root, MsgT* msg, TagType const& tag,
SequentialIDType const& seq, ReduceNumType const& num_contrib,
SequentialIDType in_seq, ReduceNumType const& num_contrib,
VirtualProxyType const& proxy
) {
return reduce<MsgT,f>(root,msg,tag,seq,num_contrib,proxy);
return reduce<MsgT,f>(root,msg,tag,in_seq,num_contrib,proxy);
}

template <typename MessageT, ActiveTypedFnType<MessageT>* f>
SequentialIDType Reduce::reduce(
NodeType root, MessageT* const msg, TagType tag, SequentialIDType seq,
Reduce::PendingSendType Reduce::reduce(
NodeType root, MessageT* const msg, TagType tag, SequentialIDType in_seq,
ReduceNumType num_contrib, VirtualProxyType proxy, ObjGroupProxyType objgroup
) {
auto msg_ptr = promoteMsg(msg);
// This currently will only set the out parameter in_seq when the pending send
// is resolved -- this might not be desirable as the pointer may have become
// invalid.
return PendingSendType{envelopeGetEpoch(msg_ptr->env), [=](){
reduceImmediate< MessageT, f >(root, msg_ptr.get(), tag, in_seq, num_contrib, proxy, objgroup);
} };
}

template <typename OpT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType Reduce::reduceImmediate(
NodeType const& root, MsgT* msg, Callback<MsgT> cb, TagType const& tag,
SequentialIDType in_seq, ReduceNumType const& num_contrib,
VirtualProxyType const& proxy
) {
msg->setCallback(cb);
return reduceImmediate<MsgT,f>(root,msg,tag,in_seq,num_contrib,proxy);
}

template <
typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType<MsgT> *f
>
SequentialIDType Reduce::reduceImmediate(
NodeType const& root, MsgT* msg, TagType const& tag,
SequentialIDType in_seq, ReduceNumType const& num_contrib,
VirtualProxyType const& proxy
) {
return reduceImmediate<MsgT,f>(root,msg,tag,in_seq,num_contrib,proxy);
}

template <typename MessageT, ActiveTypedFnType<MessageT>* f>
SequentialIDType Reduce::reduceImmediate(
NodeType root, MessageT* const msg, TagType tag, SequentialIDType in_seq,
ReduceNumType num_contrib, VirtualProxyType proxy, ObjGroupProxyType objgroup
) {
if (group_ != default_group) {
Expand All @@ -125,6 +160,9 @@ SequentialIDType Reduce::reduce(
group_, msg->reduce_tag_, msg->reduce_seq_, msg->reduce_proxy_,
msg->reduce_objgroup_, num_contrib, print_ptr(msg), envelopeGetRef(msg->env)
);

auto seq = in_seq;

if (seq == no_seq_id) {
auto reduce_seq_lookup = std::make_tuple(proxy,tag,objgroup);
auto iter = next_seq_for_tag_.find(reduce_seq_lookup);
Expand Down
4 changes: 3 additions & 1 deletion src/vt/objgroup/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "vt/objgroup/dispatch/dispatch.h"
#include "vt/messaging/message/message.h"
#include "vt/messaging/message/smart_ptr.h"
#include "vt/messaging/pending_send.h"

#include <memory>
#include <functional>
Expand All @@ -78,6 +79,7 @@ struct ObjGroupManager {
using DispatchBasePtrType = std::unique_ptr<DispatchBaseType>;
using MsgContainerType = std::vector<MsgVirtualPtrAny>;
using BaseProxyListType = std::set<ObjGroupProxyType>;
using PendingSendType = messaging::PendingSend;

ObjGroupManager() = default;

Expand Down Expand Up @@ -140,7 +142,7 @@ struct ObjGroupManager {
*/

template <typename ObjT, typename MsgT, ActiveTypedFnType<MsgT> *f>
EpochType reduce(
PendingSendType reduce(
ProxyType<ObjT> proxy, MsgSharedPtr<MsgT> msg, EpochType epoch, TagType tag
);

Expand Down
2 changes: 1 addition & 1 deletion src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ void ObjGroupManager::broadcast(MsgSharedPtr<MsgT> msg, HandlerType han) {
}

template <typename ObjT, typename MsgT, ActiveTypedFnType<MsgT> *f>
EpochType ObjGroupManager::reduce(
ObjGroupManager::PendingSendType ObjGroupManager::reduce(
ProxyType<ObjT> proxy, MsgSharedPtr<MsgT> msg, EpochType epoch, TagType tag
) {
auto const root = 0;
Expand Down
9 changes: 6 additions & 3 deletions src/vt/objgroup/proxy/proxy_objgroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@
#include "vt/collective/reduce/operators/functors/none_op.h"
#include "vt/collective/reduce/operators/callback_op.h"
#include "vt/utils/static_checks/msg_ptr.h"
#include "vt/messaging/pending_send.h"

namespace vt { namespace objgroup { namespace proxy {

template <typename ObjT>
struct Proxy {

using PendingSendType = messaging::PendingSend;

Proxy() = default;
Proxy(Proxy const&) = default;
Proxy(Proxy&&) = default;
Expand Down Expand Up @@ -94,7 +97,7 @@ struct Proxy {
MsgT, OpT, collective::reduce::operators::ReduceCallback<MsgT>
>
>
EpochType reduce(
PendingSendType reduce(
MsgPtrT msg, Callback<MsgT> cb, EpochType epoch = no_epoch,
TagType tag = no_tag
) const;
Expand All @@ -106,7 +109,7 @@ struct Proxy {
typename MsgT = typename util::MsgPtrType<MsgPtrT>::MsgType,
ActiveTypedFnType<MsgT> *f = MsgT::template msgHandler<MsgT, OpT, FunctorT>
>
EpochType reduce(
PendingSendType reduce(
MsgPtrT msg, EpochType epoch = no_epoch, TagType tag = no_tag
) const;

Expand All @@ -115,7 +118,7 @@ struct Proxy {
typename MsgT = typename util::MsgPtrType<MsgPtrT>::MsgType,
ActiveTypedFnType<MsgT> *f
>
EpochType reduce(
PendingSendType reduce(
MsgPtrT msg, EpochType epoch = no_epoch, TagType tag = no_tag
) const;

Expand Down
6 changes: 3 additions & 3 deletions src/vt/objgroup/proxy/proxy_objgroup.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ template <typename ObjT>
template <
typename OpT, typename MsgPtrT, typename MsgT, ActiveTypedFnType<MsgT> *f
>
EpochType Proxy<ObjT>::reduce(
typename Proxy<ObjT>::PendingSendType Proxy<ObjT>::reduce(
MsgPtrT inmsg, Callback<MsgT> cb, EpochType epoch, TagType tag
) const {
auto proxy = Proxy<ObjT>(*this);
Expand All @@ -92,7 +92,7 @@ template <
typename OpT, typename FunctorT, typename MsgPtrT, typename MsgT,
ActiveTypedFnType<MsgT> *f
>
EpochType Proxy<ObjT>::reduce(
typename Proxy<ObjT>::PendingSendType Proxy<ObjT>::reduce(
MsgPtrT inmsg, EpochType epoch, TagType tag
) const {
auto proxy = Proxy<ObjT>(*this);
Expand All @@ -102,7 +102,7 @@ EpochType Proxy<ObjT>::reduce(

template <typename ObjT>
template <typename MsgPtrT, typename MsgT, ActiveTypedFnType<MsgT> *f>
EpochType Proxy<ObjT>::reduce(
typename Proxy<ObjT>::PendingSendType Proxy<ObjT>::reduce(
MsgPtrT inmsg, EpochType epoch, TagType tag
) const {
auto proxy = Proxy<ObjT>(*this);
Expand Down
28 changes: 24 additions & 4 deletions src/vt/vrt/collection/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,35 +389,55 @@ struct CollectionManager {
* Reduce all elements of a collection
*/
template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType reduceMsg(
messaging::PendingSend reduceMsg(
CollectionProxyWrapType<ColT, typename ColT::IndexType> const& toProxy,
MsgT *const msg, SequentialIDType seq = no_seq_id,
TagType tag = no_tag, NodeType root_node = uninitialized_destination
);

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType reduceMsg(
messaging::PendingSend reduceMsg(
CollectionProxyWrapType<ColT, typename ColT::IndexType> const& toProxy,
MsgT *const msg, SequentialIDType seq, TagType tag,
typename ColT::IndexType const& idx
);

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType reduceMsgExpr(
messaging::PendingSend reduceMsgExpr(
CollectionProxyWrapType<ColT, typename ColT::IndexType> const& toProxy,
MsgT *const msg, ReduceIdxFuncType<typename ColT::IndexType> expr_fn,
SequentialIDType seq = no_seq_id, TagType tag = no_tag,
NodeType root_node = uninitialized_destination
);

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType reduceMsgExpr(
messaging::PendingSend reduceMsgExpr(
CollectionProxyWrapType<ColT, typename ColT::IndexType> const& toProxy,
MsgT *const msg, ReduceIdxFuncType<typename ColT::IndexType> expr_fn,
SequentialIDType seq, TagType tag,
typename ColT::IndexType const& idx
);

private:

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType reduceMsgExprImpl(
CollectionProxyWrapType<ColT, typename ColT::IndexType> const& toProxy,
MsgT *const msg, ReduceIdxFuncType<typename ColT::IndexType> expr_fn,
SequentialIDType seq = no_seq_id, TagType tag = no_tag,
NodeType root_node = uninitialized_destination
);

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
SequentialIDType reduceMsgExprImpl(
CollectionProxyWrapType<ColT, typename ColT::IndexType> const& toProxy,
MsgT *const msg, ReduceIdxFuncType<typename ColT::IndexType> expr_fn,
SequentialIDType seq, TagType tag,
typename ColT::IndexType const& idx
);

public:

/*
* Broadcast message to all elements of a collection
*/
Expand Down
Loading

0 comments on commit 7e8dc44

Please sign in to comment.