Skip to content

Commit

Permalink
Merge pull request #547 from EricaCMitchell/pr-world_reduce
Browse files Browse the repository at this point in the history
Decrease time for world.gop.reduce
  • Loading branch information
evaleev authored Sep 16, 2024
2 parents 4e2c648 + ef73ffd commit 0d87255
Showing 1 changed file with 72 additions and 13 deletions.
85 changes: 72 additions & 13 deletions src/madness/world/worldgop.h
Original file line number Diff line number Diff line change
Expand Up @@ -783,30 +783,71 @@ namespace madness {
void reduce(T* buf, std::size_t nelem, opT op) {
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
const std::size_t nelem_per_maxmsg = max_reducebcast_msg_size() / sizeof(T);

auto buf0 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);
auto buf1 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);
const std::size_t nelem_per_maxmsg =
max_reducebcast_msg_size() / sizeof(T);

const auto buf_size = ((sizeof(T) * std::min(nelem_per_maxmsg, nelem) +
std::alignment_of_v<T> - 1) /
std::alignment_of_v<T>) * std::alignment_of_v<T>;
/// Deleter that hands a T buffer back to the C allocator via std::free;
/// used as the deleter type of the unique_ptr alias declared right after it.
struct free_dtor {
  void operator()(T *buffer) {
    // std::free(nullptr) is defined to do nothing, so no explicit null guard
    std::free(buffer);
  }
};
using sptr_t = std::unique_ptr<T[], free_dtor>;

sptr_t buf0;
/// Allocates one receive buffer of at least buf_size bytes, aligned suitably
/// for T. The returned pointer must be released with std::free.
/// \throw std::bad_alloc if the underlying allocator fails (both code paths)
auto aligned_buf_alloc = [&]() -> T* {
  // posix_memalign requires the alignment to be a power-of-two multiple of
  // sizeof(void*); alignof(T) is a power of two, so rounding it up to a
  // multiple of sizeof(void*) gives max(alignof(T), sizeof(void*))
  const std::size_t alignment =
      ((std::alignment_of_v<T> + sizeof(void *) - 1) /
       sizeof(void *)) *
      sizeof(void *);
  // std::aligned_alloc wants the size to be an integral multiple of the
  // alignment actually requested (which may exceed alignof(T)), and a
  // zero-size request has implementation-defined results; pad accordingly
  const std::size_t nonzero_size = (buf_size == 0) ? std::size_t(1) : buf_size;
  const std::size_t alloc_size =
      ((nonzero_size + alignment - 1) / alignment) * alignment;
  void *ptr = nullptr;
#ifdef HAVE_POSIX_MEMALIGN
  if (posix_memalign(&ptr, alignment, alloc_size) != 0) {
    throw std::bad_alloc();
  }
#else
  ptr = std::aligned_alloc(alignment, alloc_size);
  // report failure the same way as the posix_memalign path instead of
  // handing a null buffer to the receive calls
  if (ptr == nullptr) {
    throw std::bad_alloc();
  }
#endif
  return static_cast<T *>(ptr);
};
if (child0 != -1)
buf0 = sptr_t(aligned_buf_alloc(),
free_dtor{});
sptr_t buf1(nullptr);
if (child1 != -1)
buf1 = sptr_t(aligned_buf_alloc(),
free_dtor{});

auto reduce_impl = [&,this](T* buf, size_t nelem) {
MADNESS_ASSERT(nelem <= nelem_per_maxmsg);
SafeMPI::Request req0, req1;
Tag gsum_tag = world_.mpi.unique_tag();

if (child0 != -1) req0 = world_.mpi.Irecv(buf0.get(), nelem*sizeof(T), MPI_BYTE, child0, gsum_tag);
if (child1 != -1) req1 = world_.mpi.Irecv(buf1.get(), nelem*sizeof(T), MPI_BYTE, child1, gsum_tag);
if (child0 != -1)
req0 = world_.mpi.Irecv(buf0.get(), nelem * sizeof(T), MPI_BYTE,
child0, gsum_tag);
if (child1 != -1)
req1 = world_.mpi.Irecv(buf1.get(), nelem * sizeof(T), MPI_BYTE,
child1, gsum_tag);

if (child0 != -1) {
World::await(req0);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf0[i]);
for (long i = 0; i < (long)nelem; ++i)
buf[i] = op(buf[i], buf0[i]);
}
if (child1 != -1) {
World::await(req1);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf1[i]);
for (long i = 0; i < (long)nelem; ++i)
buf[i] = op(buf[i], buf1[i]);
}

if (parent != -1) {
req0 = world_.mpi.Isend(buf, nelem*sizeof(T), MPI_BYTE, parent, gsum_tag);
req0 = world_.mpi.Isend(buf, nelem * sizeof(T), MPI_BYTE, parent,
gsum_tag);
World::await(req0);
}

Expand Down Expand Up @@ -903,7 +944,7 @@ namespace madness {
/// Concatenate an STL vector of serializable stuff onto node 0

/// \param[in] v input vector
/// \param[in] bufsz the max of the result' must be less than std::numeric_limits<int>::max()
/// \param[in] bufsz the max number of bytes in the result; must be less than std::numeric_limits<int>::max()
/// \return on rank 0 returns the concatenated vector, elsewhere returns an empty vector
template <typename T>
std::vector<T> concat0(const std::vector<T>& v, size_t bufsz=1024*1024) {
Expand All @@ -913,10 +954,28 @@ namespace madness {
world_.mpi.binary_tree_info(0, parent, child0, child1);
int child0_nbatch = 0, child1_nbatch = 0;

auto buf0 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);
auto buf1 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);
/// Deleter that hands a std::byte buffer back to the C allocator via
/// std::free; used as the deleter type of the unique_ptr alias that follows.
struct free_dtor {
  void operator()(std::byte *buffer) {
    // std::free(nullptr) is defined to do nothing, so no explicit null guard
    std::free(buffer);
  }
};
using sptr_t = std::unique_ptr<std::byte[], free_dtor>;

sptr_t buf0;
if (child0 != -1)
buf0 = sptr_t(static_cast<std::byte *>(std::aligned_alloc(
std::alignment_of_v<T>, bufsz)),
free_dtor{});
sptr_t buf1;
if (child1 != -1)
buf1 = sptr_t(static_cast<std::byte *>(std::aligned_alloc(
std::alignment_of_v<T>, bufsz)),
free_dtor{});

// transfer data in chunks at most this large
const int batch_size = static_cast<int>(std::min(static_cast<size_t>(max_reducebcast_msg_size()),bufsz));
const int batch_size = static_cast<int>(
std::min(static_cast<size_t>(max_reducebcast_msg_size()),bufsz));

// precompute max # of tags any node ... will need, and allocate them on every node to avoid tag counter divergence
const int max_nbatch = bufsz / batch_size;
Expand Down

0 comments on commit 0d87255

Please sign in to comment.