diff --git a/src/madness/world/worldgop.h b/src/madness/world/worldgop.h index 34817a755c5..f1acc7b1452 100644 --- a/src/madness/world/worldgop.h +++ b/src/madness/world/worldgop.h @@ -783,30 +783,71 @@ namespace madness { void reduce(T* buf, std::size_t nelem, opT op) { ProcessID parent, child0, child1; world_.mpi.binary_tree_info(0, parent, child0, child1); - const std::size_t nelem_per_maxmsg = max_reducebcast_msg_size() / sizeof(T); - - auto buf0 = std::unique_ptr(new T[nelem_per_maxmsg]); - auto buf1 = std::unique_ptr(new T[nelem_per_maxmsg]); + const std::size_t nelem_per_maxmsg = + max_reducebcast_msg_size() / sizeof(T); + + const auto buf_size = ((sizeof(T) * std::min(nelem_per_maxmsg, nelem) + + std::alignment_of_v - 1) / + std::alignment_of_v) * std::alignment_of_v; + struct free_dtor { + void operator()(T *ptr) { + if (ptr != nullptr) + std::free(ptr); + }; + }; + using sptr_t = std::unique_ptr; + + sptr_t buf0; + auto aligned_buf_alloc = [&]() -> T* { + // posix_memalign requires alignment to be an integer multiple of sizeof(void*)!! so ensure that + const std::size_t alignment = + ((std::alignment_of_v + sizeof(void *) - 1) / + sizeof(void *)) * + sizeof(void *); +#ifdef HAVE_POSIX_MEMALIGN + void *ptr; + if (posix_memalign(&ptr, alignment, buf_size) != 0) { + throw std::bad_alloc(); + } + return static_cast(ptr); +#else + return static_cast(std::aligned_alloc(alignment, buf_size)); +#endif + }; + if (child0 != -1) + buf0 = sptr_t(aligned_buf_alloc(), + free_dtor{}); + sptr_t buf1(nullptr); + if (child1 != -1) + buf1 = sptr_t(aligned_buf_alloc(), + free_dtor{}); auto reduce_impl = [&,this](T* buf, size_t nelem) { MADNESS_ASSERT(nelem <= nelem_per_maxmsg); SafeMPI::Request req0, req1; Tag gsum_tag = world_.mpi.unique_tag(); - if (child0 != -1) req0 = world_.mpi.Irecv(buf0.get(), nelem*sizeof(T), MPI_BYTE, child0, gsum_tag); - if (child1 != -1) req1 = world_.mpi.Irecv(buf1.get(), nelem*sizeof(T), MPI_BYTE, child1, gsum_tag); + if (child0 != -1) + req0 = world_.mpi.Irecv(buf0.get(), nelem * sizeof(T), MPI_BYTE, + child0, gsum_tag); + if (child1 != -1) + req1 = world_.mpi.Irecv(buf1.get(), nelem * sizeof(T), MPI_BYTE, + child1, gsum_tag); if (child0 != -1) { World::await(req0); - for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf0[i]); + for (long i = 0; i < (long)nelem; ++i) + buf[i] = op(buf[i], buf0[i]); } if (child1 != -1) { World::await(req1); - for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf1[i]); + for (long i = 0; i < (long)nelem; ++i) + buf[i] = op(buf[i], buf1[i]); } if (parent != -1) { - req0 = world_.mpi.Isend(buf, nelem*sizeof(T), MPI_BYTE, parent, gsum_tag); + req0 = world_.mpi.Isend(buf, nelem * sizeof(T), MPI_BYTE, parent, + gsum_tag); World::await(req0); } @@ -903,7 +944,7 @@ namespace madness { /// Concatenate an STL vector of serializable stuff onto node 0 /// \param[in] v input vector - /// \param[in] bufsz the max of the result' must be less than std::numeric_limits::max() + /// \param[in] bufsz the max number of bytes in the result; must be less than std::numeric_limits::max() /// \return on rank 0 returns the concatenated vector, elsewhere returns an empty vector template std::vector concat0(const std::vector& v, size_t bufsz=1024*1024) { @@ -913,10 +954,28 @@ namespace madness { world_.mpi.binary_tree_info(0, parent, child0, child1); int child0_nbatch = 0, child1_nbatch = 0; - auto buf0 = std::unique_ptr(new std::byte[bufsz]); - auto buf1 = std::unique_ptr(new std::byte[bufsz]); + struct free_dtor { + void operator()(std::byte *ptr) { + if (ptr != nullptr) + std::free(ptr); + }; + }; + using sptr_t = std::unique_ptr; + + sptr_t buf0; + if (child0 != -1) + buf0 = sptr_t(static_cast(std::aligned_alloc( + std::alignment_of_v, bufsz)), + free_dtor{}); + sptr_t buf1; + if (child1 != -1) + buf1 = sptr_t(static_cast(std::aligned_alloc( + std::alignment_of_v, bufsz)), + free_dtor{}); + // transfer data in chunks at most this large - const int batch_size = static_cast(std::min(static_cast(max_reducebcast_msg_size()),bufsz)); + const int batch_size = static_cast( + std::min(static_cast(max_reducebcast_msg_size()),bufsz)); // precompute max # of tags any node ... will need, and allocate them on every node to avoid tag counter divergence const int max_nbatch = bufsz / batch_size;