Skip to content

Commit

Permalink
#2281: Working AllreduceHolder with all (Collection/Group/ObjGroup) c…
Browse files Browse the repository at this point in the history
…omponents
  • Loading branch information
JacobDomagala committed Sep 21, 2024
1 parent dac1b1a commit 455829a
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 109 deletions.
70 changes: 62 additions & 8 deletions src/vt/collective/reduce/allreduce/allreduce_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ objgroup::proxy::Proxy<Rabenseifner> AllreduceHolder::addRabensifnerAllreducer(

col_reducers_[coll_proxy].first = obj_proxy.getProxy();

fmt::print(
"Adding new Rabenseifner reducer for collection={:x}\n", coll_proxy);
vt_debug_print(
verbose, allreduce, "Adding new Rabenseifner reducer for collection={:x}",
coll_proxy
);

return obj_proxy;
}
Expand All @@ -70,8 +72,11 @@ AllreduceHolder::addRecursiveDoublingAllreducer(
"recursive_doubling_allreducer", strong_proxy, strong_group, num_elems);

col_reducers_[coll_proxy].second = obj_proxy.getProxy();
fmt::print(
"Adding new RecursiveDoubling reducer for collection={:x}\n", coll_proxy);

vt_debug_print(
verbose, allreduce,
"Adding new RecursiveDoubling reducer for collection={:x}", coll_proxy
);

return obj_proxy;
}
Expand All @@ -85,9 +90,10 @@ AllreduceHolder::addRabensifnerAllreducer(detail::StrongGroup strong_group) {

group_reducers_[group].first = obj_proxy.getProxy();

fmt::print(
"Adding new Rabenseifner reducer for group={:x} Size={}\n", group,
group_reducers_.size());
vt_debug_print(
verbose, allreduce,
"Adding new Rabenseifner reducer for group={:x}", group
);

return obj_proxy;
}
Expand All @@ -100,13 +106,51 @@ AllreduceHolder::addRecursiveDoublingAllreducer(
auto obj_proxy = theObjGroup()->makeCollective<RecursiveDoubling>(
"recursive_doubling_allreducer", strong_group);

fmt::print("Adding new RecursiveDoubling reducer for group={:x}\n", group);
vt_debug_print(
verbose, allreduce,
"Adding new Rabenseifner reducer for group={:x}", group
);

group_reducers_[group].second = obj_proxy.getProxy();

return obj_proxy;
}

/**
 * \brief Create a Rabenseifner allreducer for an object group and register it
 *
 * Builds the reducer through theObjGroup()->makeCollective and stores its
 * untyped proxy in the first slot of the objgroup_reducers_ entry keyed by
 * the object group's proxy value.
 *
 * \param strong_objgroup strongly-typed object group proxy the reducer serves
 *
 * \return typed proxy of the newly created Rabenseifner reducer
 */
objgroup::proxy::Proxy<Rabenseifner>
AllreduceHolder::addRabensifnerAllreducer(detail::StrongObjGroup strong_objgroup) {
  auto const objgroup = strong_objgroup.get();

  auto obj_proxy = theObjGroup()->makeCollective<Rabenseifner>(
    "rabenseifer_allreducer", strong_objgroup);

  objgroup_reducers_[objgroup].first = obj_proxy.getProxy();

  // Exactly one placeholder per argument: the format string used to carry a
  // stray "Size={}" field with no matching argument.
  vt_debug_print(
    verbose, allreduce,
    "Adding new Rabenseifner reducer for objgroup={:x}\n", objgroup
  );

  return obj_proxy;
}

/**
 * \brief Create a RecursiveDoubling allreducer for an object group and
 * register it
 *
 * The reducer is built through theObjGroup()->makeCollective; its untyped
 * proxy is then stored in the second slot of the objgroup_reducers_ entry
 * keyed by the object group's proxy value.
 *
 * \param strong_objgroup strongly-typed object group proxy the reducer serves
 *
 * \return typed proxy of the newly created RecursiveDoubling reducer
 */
objgroup::proxy::Proxy<RecursiveDoubling>
AllreduceHolder::addRecursiveDoublingAllreducer(
  detail::StrongObjGroup strong_objgroup
) {
  auto const key = strong_objgroup.get();

  auto reducer = theObjGroup()->makeCollective<RecursiveDoubling>(
    "recursive_doubling_allreducer", strong_objgroup
  );

  vt_debug_print(
    verbose, allreduce,
    "Adding new RecursiveDoubling reducer for objgroup={:x}\n", key
  );

  // Fetch the untyped proxy first, then write it into the holder's map
  auto const untyped = reducer.getProxy();
  objgroup_reducers_[key].second = untyped;

  return reducer;
}

void AllreduceHolder::remove(detail::StrongVrtProxy strong_proxy) {
auto const key = strong_proxy.get();

Expand All @@ -127,4 +171,14 @@ void AllreduceHolder::remove(detail::StrongGroup strong_group) {
}
}

/**
 * \brief Drop the reducer entry stored for an object group, if there is one
 *
 * \param strong_objgroup strongly-typed object group proxy used as the key
 */
void AllreduceHolder::remove(detail::StrongObjGroup strong_objgroup) {
  // Erasing by key does nothing when the key is absent, so no lookup is
  // needed beforehand.
  objgroup_reducers_.erase(strong_objgroup.get());
}

} // namespace vt::collective::reduce::allreduce
33 changes: 33 additions & 0 deletions src/vt/collective/reduce/allreduce/allreduce_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,33 @@ struct AllreduceHolder {
}
}

/**
 * \brief Get the allreducer of kind \c ReducerT for an object group,
 * creating it on first request
 *
 * An entry with both slots set to \c u64empty is inserted into
 * objgroup_reducers_ when the object group has none yet. The first slot of
 * the entry belongs to Rabenseifner (selected when \c ReducerT is
 * \c RabenseifnerT), the second to RecursiveDoubling (any other \c ReducerT).
 * A slot still equal to \c u64empty means the reducer does not exist yet and
 * the matching add*Allreducer function is called to create it.
 *
 * \param strong_objgroup strongly-typed object group proxy
 *
 * \return typed proxy of the requested reducer
 */
template <typename ReducerT>
static auto getAllreducer(detail::StrongObjGroup strong_objgroup) {
  auto const key = strong_objgroup.get();

  // Inserts {u64empty, u64empty} only when the key is missing
  auto const& slots =
    objgroup_reducers_.try_emplace(key, u64empty, u64empty).first->second;

  if constexpr (std::is_same_v<ReducerT, RabenseifnerT>) {
    auto const stored = slots.first;
    if (stored != u64empty) {
      return static_cast<vt::objgroup::proxy::Proxy<Rabenseifner>>(stored);
    }
    return addRabensifnerAllreducer(strong_objgroup);
  } else {
    auto const stored = slots.second;
    if (stored != u64empty) {
      return static_cast<vt::objgroup::proxy::Proxy<RecursiveDoubling>>(
        stored
      );
    }
    return addRecursiveDoublingAllreducer(strong_objgroup);
  }
}

/**
 * \brief Create a Rabenseifner allreducer for a collection and record its
 * untyped proxy in the first slot of the collection's col_reducers_ entry
 *
 * \param strong_proxy strongly-typed collection proxy
 * \param strong_group strongly-typed group (NOTE(review): presumably the
 * group spanning the collection's nodes; confirm against the caller)
 * \param num_elems element count forwarded to the reducer (NOTE(review):
 * looks like the number of local collection elements; confirm)
 *
 * \return typed proxy of the created reducer
 */
static objgroup::proxy::Proxy<Rabenseifner> addRabensifnerAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems);
Expand All @@ -132,8 +159,14 @@ struct AllreduceHolder {
/**
 * \brief Create a RecursiveDoubling allreducer for a group and record its
 * untyped proxy in the second slot of the group's group_reducers_ entry
 *
 * \param strong_group strongly-typed group the reducer serves
 *
 * \return typed proxy of the created reducer
 */
static objgroup::proxy::Proxy<RecursiveDoubling>
addRecursiveDoublingAllreducer(detail::StrongGroup strong_group);

/**
 * \brief Create a Rabenseifner allreducer for an object group and record its
 * untyped proxy in the first slot of the object group's objgroup_reducers_
 * entry
 *
 * \param strong_group strongly-typed object group proxy (an object group,
 * despite the parameter name)
 *
 * \return typed proxy of the created reducer
 */
static objgroup::proxy::Proxy<Rabenseifner>
addRabensifnerAllreducer(detail::StrongObjGroup strong_group);
/**
 * \brief Create a RecursiveDoubling allreducer for an object group and record
 * its untyped proxy in the second slot of the object group's
 * objgroup_reducers_ entry
 *
 * \param strong_group strongly-typed object group proxy (an object group,
 * despite the parameter name)
 *
 * \return typed proxy of the created reducer
 */
static objgroup::proxy::Proxy<RecursiveDoubling>
addRecursiveDoublingAllreducer(detail::StrongObjGroup strong_group);

/**
 * \brief Remove the reducer entry kept for a collection
 *
 * NOTE(review): the definition is only partly visible in this view;
 * presumably it erases the col_reducers_ entry for the proxy. Confirm.
 *
 * \param strong_proxy strongly-typed collection proxy used as the key
 */
static void remove(detail::StrongVrtProxy strong_proxy);
/**
 * \brief Remove the reducer entry kept for a group
 *
 * NOTE(review): the definition is only partly visible in this view;
 * presumably it erases the group_reducers_ entry for the group. Confirm.
 *
 * \param strong_group strongly-typed group used as the key
 */
static void remove(detail::StrongGroup strong_group);
/**
 * \brief Remove the objgroup_reducers_ entry kept for an object group, if any
 *
 * \param strong_group strongly-typed object group proxy used as the key (an
 * object group, despite the parameter name)
 */
static void remove(detail::StrongObjGroup strong_group);

static inline std::unordered_map<
VirtualProxyType, std::pair<RabenseifnerProxy, RecursiveDoublingProxy>>
Expand Down
50 changes: 15 additions & 35 deletions src/vt/collective/reduce/allreduce/rabenseifner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
*/

#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/collective/reduce/allreduce/allreduce_holder.h"
#include "vt/configs/error/config_assert.h"
#include "vt/group/group_manager.h"

Expand All @@ -67,16 +68,7 @@ Rabenseifner::Rabenseifner(
}

is_even_ = this_node_ % 2 == 0;
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
if (is_part_of_adjustment_group_) {
if (is_even_) {
vrt_node_ = this_node_ / 2;
} else {
vrt_node_ = -1;
}
} else {
vrt_node_ = this_node_ - nprocs_rem_;
}
initializeVrtNode();

vt_debug_print(
terse, allreduce,
Expand All @@ -93,21 +85,15 @@ Rabenseifner::Rabenseifner(detail::StrongGroup group)
num_steps_(static_cast<int32_t>(log2(num_nodes_))),
nprocs_pof2_(1 << num_steps_),
nprocs_rem_(num_nodes_ - nprocs_pof2_) {
std::string nodes_info;
for (auto& node : nodes_) {
nodes_info += fmt::format("{} ", node);
}
auto const is_default_group = theGroup()->isGroupDefault(group_);
auto const in_group = theGroup()->inGroup(group_);
auto const is_part_of_allreduce =
(not is_default_group and in_group) or
is_default_group;
(not is_default_group and in_group) or is_default_group;

vt_debug_print(
terse, allreduce,
"Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} "
"Nodes:[{}]\n",
is_default_group, is_part_of_allreduce, num_nodes_, nodes_info);
"Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} \n",
is_default_group, is_part_of_allreduce, num_nodes_);

if (not is_default_group and in_group) {
auto it = std::find(nodes_.begin(), nodes_.end(), theContext()->getNode());
Expand All @@ -120,16 +106,8 @@ Rabenseifner::Rabenseifner(detail::StrongGroup group)
// We collectively create this Reducer, so it's possible that not all Nodes are part of it
if (is_part_of_allreduce) {
is_even_ = this_node_ % 2 == 0;
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
if (is_part_of_adjustment_group_) {
if (is_even_) {
vrt_node_ = this_node_ / 2;
} else {
vrt_node_ = -1;
}
} else {
vrt_node_ = this_node_ - nprocs_rem_;
}

initializeVrtNode();
}
}

Expand All @@ -141,17 +119,21 @@ Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup)
num_steps_(static_cast<int32_t>(log2(num_nodes_))),
nprocs_pof2_(1 << num_steps_),
nprocs_rem_(num_nodes_ - nprocs_pof2_) {

nodes_.resize(num_nodes_);
for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) {
nodes_[i] = i;
}

initializeVrtNode();

vt_debug_print(
terse, allreduce,
"Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} \n",
true, true, num_nodes_);
}

// We collectively create this Reducer, so it's possible that not all Nodes are part of it
void Rabenseifner::initializeVrtNode() {
is_even_ = this_node_ % 2 == 0;
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
if (is_part_of_adjustment_group_) {
Expand All @@ -166,12 +148,10 @@ Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup)
}

Rabenseifner::~Rabenseifner() {
if (collection_proxy_ != u64empty) {
// StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_});

} else if (objgroup_proxy_ != u64empty) {
if (objgroup_proxy_ != u64empty) {
StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_});
} else {
AllreduceHolder::remove(detail::StrongObjGroup{objgroup_proxy_});
} else if(group_ != u64empty){
StateHolder::clearAll(detail::StrongGroup{group_});
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ struct Rabenseifner {
* \param objgroup ObjGroupProxy
*/
Rabenseifner(detail::StrongObjGroup objgroup);

/**
 * \brief Initialize vrt_node_ variable
 *
 * Also sets is_even_ and is_part_of_adjustment_group_ from this_node_ and
 * nprocs_rem_ before deriving vrt_node_.
 */
void initializeVrtNode();

~Rabenseifner();

/**
Expand Down
38 changes: 14 additions & 24 deletions src/vt/collective/reduce/allreduce/recursive_doubling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
*/

#include "vt/collective/reduce/allreduce/recursive_doubling.h"
#include "vt/collective/reduce/scoping/strong_types.h"
#include "vt/group/group_manager.h"

namespace vt::collective::reduce::allreduce {
Expand All @@ -67,24 +66,15 @@ RecursiveDoubling::RecursiveDoubling(

is_even_ = this_node_ % 2 == 0;
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
if (is_part_of_adjustment_group_) {
if (is_even_) {
vrt_node_ = this_node_ / 2;
} else {
vrt_node_ = -1;
}
} else {
vrt_node_ = this_node_ - nprocs_rem_;
}
initializeVrtNode();

vt_debug_print(
terse, allreduce,
"RecursiveDoubling (this={}): proxy={:x} proxy_={} local_num_elems={}\n",
print_ptr(this), proxy.get(), proxy_.getProxy(), local_num_elems_);
}

RecursiveDoubling::RecursiveDoubling(
detail::StrongObjGroup objgroup)
RecursiveDoubling::RecursiveDoubling(detail::StrongObjGroup objgroup)
: objgroup_proxy_(objgroup.get()),
local_num_elems_(1),
num_nodes_(theContext()->getNumNodes()),
Expand All @@ -94,15 +84,12 @@ RecursiveDoubling::RecursiveDoubling(
nprocs_pof2_(1 << num_steps_),
nprocs_rem_(num_nodes_ - nprocs_pof2_),
is_part_of_adjustment_group_(this_node_ < (2 * nprocs_rem_)) {
if (is_part_of_adjustment_group_) {
if (is_even_) {
vrt_node_ = this_node_ / 2;
} else {
vrt_node_ = -1;
}
} else {
vrt_node_ = this_node_ - nprocs_rem_;
nodes_.resize(num_nodes_);
for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) {
nodes_[i] = i;
}

initializeVrtNode();
}

RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group)
Expand All @@ -116,6 +103,10 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group)
nprocs_pof2_(1 << num_steps_),
nprocs_rem_(num_nodes_ - nprocs_pof2_),
is_part_of_adjustment_group_(this_node_ < (2 * nprocs_rem_)) {
initializeVrtNode();
}

void RecursiveDoubling::initializeVrtNode() {
if (is_part_of_adjustment_group_) {
if (is_even_) {
vrt_node_ = this_node_ / 2;
Expand All @@ -128,11 +119,10 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group)
}

RecursiveDoubling::~RecursiveDoubling() {
if (collection_proxy_ != u64empty) {
// StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_});
} else if (objgroup_proxy_ != u64empty) {
if (objgroup_proxy_ != u64empty) {
StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_});
} else {
AllreduceHolder::remove(detail::StrongObjGroup{objgroup_proxy_});
} else if (group_ != u64empty) {
StateHolder::clearAll(detail::StrongGroup{group_});
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ struct RecursiveDoubling {
*/
RecursiveDoubling(detail::StrongGroup group);

/**
* \brief Initialize vrt_node_ variable
*/
void initializeVrtNode();

~RecursiveDoubling();

/**
Expand Down
4 changes: 0 additions & 4 deletions src/vt/objgroup/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,6 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
std::unordered_map<ObjGroupProxyType, std::vector<ActionType>> pending_;
/// Map of object groups' labels
std::unordered_map<ObjGroupProxyType, std::string> labels_;
/// Recursive Doubling reducers
ReducerMapType reducers_recursive_doubling_;
/// Rabenseifner reducers
ReducerMapType reducers_rabenseifner_;
};

}} /* end namespace vt::objgroup */
Expand Down
Loading

0 comments on commit 455829a

Please sign in to comment.