From 921bd41f29b320a7dd3495b48cf860953172cddc Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Thu, 10 Jun 2021 14:21:24 +0000 Subject: [PATCH 1/8] adding test that breaks current implementation --- .../test/scheduler/test_remote_mpi_worlds.cpp | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 84aeaa2cf..f275bc672 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -366,4 +366,54 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, senderThread.join(); localWorld.destroy(); } + +TEST_CASE_METHOD(RemoteMpiTestFixture, + "Test sending sync and async message to same host", + "[mpi]") +{ + // Allocate two ranks in total, one rank per host + this->setWorldsSizes(2, 1, 1); + int sendRank = 1; + int recvRank = 0; + std::vector messageData = { 0, 1, 2 }; + + // Initi world + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + + std::thread senderThread([this, sendRank, recvRank] { + std::vector messageData = { 0, 1, 2 }; + + remoteWorld.initialiseFromMsg(msg); + + // Send message twice + remoteWorld.send( + sendRank, recvRank, BYTES(messageData.data()), MPI_INT, messageData.size()); + remoteWorld.send( + sendRank, recvRank, BYTES(messageData.data()), MPI_INT, messageData.size()); + + usleep(1000 * 500); + remoteWorld.destroy(); + }); + + // Receive one message asynchronously + std::vector asyncMessage(messageData.size(), 0); + int recvId = + localWorld.irecv(sendRank, recvRank, BYTES(asyncMessage.data()), MPI_INT, asyncMessage.size()); + + // Receive one message synchronously + std::vector syncMessage(messageData.size(), 0); + localWorld.recv(sendRank, recvRank, BYTES(syncMessage.data()), MPI_INT, syncMessage.size(), MPI_STATUS_IGNORE); + + // Wait for the async message + localWorld.awaitAsyncRequest(recvId); + + // Checks + REQUIRE(syncMessage == messageData); + REQUIRE(asyncMessage == messageData); + + // Destroy world + senderThread.join(); + localWorld.destroy(); +} } From 17a48d1f42735f31cb8bc08e6e777adaafdaa04a Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Fri, 11 Jun 2021 10:43:04 +0000 Subject: [PATCH 2/8] removing thread pool and implementing the umb --- include/faabric/scheduler/MpiThreadPool.h | 40 --- include/faabric/scheduler/MpiWorld.h | 54 ++- .../faabric/transport/MpiMessageEndpoint.h | 4 +- src/scheduler/CMakeLists.txt | 1 - src/scheduler/MpiThreadPool.cpp | 67 ---- src/scheduler/MpiWorld.cpp | 331 ++++++++++++------ .../test/scheduler/test_remote_mpi_worlds.cpp | 30 +- 7 files changed, 289 insertions(+), 238 deletions(-) delete mode 100644 include/faabric/scheduler/MpiThreadPool.h delete mode 100644 src/scheduler/MpiThreadPool.cpp diff --git a/include/faabric/scheduler/MpiThreadPool.h b/include/faabric/scheduler/MpiThreadPool.h deleted file mode 100644 index 4483c769d..000000000 --- a/include/faabric/scheduler/MpiThreadPool.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#define QUEUE_SHUTDOWN -1 - -namespace faabric::scheduler { -typedef std::tuple, std::promise> - ReqQueueType; -typedef faabric::util::Queue MpiReqQueue; - -class MpiAsyncThreadPool -{ - public: - explicit MpiAsyncThreadPool(int nThreads); - - void shutdown(); - - int size; - - std::shared_ptr getMpiReqQueue(); - - private: - std::vector threadPool; - std::atomic 
isShutdown; - - std::shared_ptr localReqQueue; - - void entrypoint(int i); -}; -} diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index dfe07c99e..ee716c716 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -3,20 +3,45 @@ #include #include -#include #include -#include #include #include #include #include -#include +#include namespace faabric::scheduler { typedef faabric::util::Queue> InMemoryMpiQueue; +/* The untracked message buffer (UMB) keeps track of the asyncrhonous + * messages that we must have received (i.e. through an irecv call) but we + * still have not waited on (acknowledged). Messages are acknowledged either + * through a call to recv or a call to await. A call to recv will + * acknowledge (i.e. synchronously read from transport buffers) as many + * unacknowleged messages there are, plus one. + */ +struct UnackedMessageBuffer +{ + struct Arguments + { + int sendRank; + int recvRank; + uint8_t* buffer; + faabric_datatype_t* dataType; + int count; + faabric::MPIMessage::MPIMessageType messageType; + }; + + // We keep track of: the request id, its arguments, and the message it + // acknowledges (may be null if unacknowleged). All three lists should + // always be the same size. + std::list ids; + std::list args; + std::list> msgs; +}; + class MpiWorld { public: @@ -41,8 +66,6 @@ class MpiWorld void destroy(); - void shutdownThreadPool(); - void getCartesianRank(int rank, int maxDims, const int* dims, @@ -196,9 +219,6 @@ class MpiWorld std::string user; std::string function; - std::shared_ptr threadPool; - int getMpiThreadPoolSize(); - std::vector cartProcsPerDim; /* MPI internal messaging layer */ @@ -221,6 +241,24 @@ class MpiWorld int recvRank); void closeMpiMessageEndpoints(); + // Support for asyncrhonous communications + std::shared_ptr + getUnackedMessageBuffer(int sendRank, int recvRank); + std::shared_ptr recvBatchReturnLast(int sendRank, + int recvRank, + int batchSize = 0); + + /* Helper methods */ + void checkRanksRange(int sendRank, int recvRank); + + // Abstraction of the bulk of the recv work, shared among various functions + void doRecv(std::shared_ptr m, + uint8_t* buffer, + faabric_datatype_t* dataType, + int count, + MPI_Status* status, + faabric::MPIMessage::MPIMessageType messageType = + faabric::MPIMessage::NORMAL); }; } diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 40067fc6b..79c402a12 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -14,8 +14,8 @@ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); void sendMpiHostRankMsg(const std::string& hostIn, const faabric::MpiHostsToRanksMessage msg); -/* This class abstracts the notion of a communication channel between two MPI - * ranks. There will always be one rank local to this host, and one remote. +/* This class abstracts the notion of a communication channel between two remote + * MPI ranks. There will always be one rank local to this host, and one remote. * Note that the port is unique per (user, function, sendRank, recvRank) tuple. 
*/ class MpiMessageEndpoint diff --git a/src/scheduler/CMakeLists.txt b/src/scheduler/CMakeLists.txt index ce4cb75d1..a1abd4374 100644 --- a/src/scheduler/CMakeLists.txt +++ b/src/scheduler/CMakeLists.txt @@ -10,7 +10,6 @@ set(LIB_FILES SnapshotServer.cpp SnapshotClient.cpp MpiContext.cpp - MpiThreadPool.cpp MpiWorldRegistry.cpp MpiWorld.cpp ${HEADERS} diff --git a/src/scheduler/MpiThreadPool.cpp b/src/scheduler/MpiThreadPool.cpp deleted file mode 100644 index 6f6f088db..000000000 --- a/src/scheduler/MpiThreadPool.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include - -namespace faabric::scheduler { -MpiAsyncThreadPool::MpiAsyncThreadPool(int nThreads) - : size(nThreads) - , isShutdown(false) -{ - SPDLOG_DEBUG("Starting an MpiAsyncThreadPool of size {}", nThreads); - - // Initialize async. req queue - localReqQueue = std::make_shared(); - - // Initialize thread pool - for (int i = 0; i < nThreads; ++i) { - threadPool.emplace_back( - std::bind(&MpiAsyncThreadPool::entrypoint, this, i)); - } -} - -void MpiAsyncThreadPool::shutdown() -{ - SPDLOG_DEBUG("Shutting down MpiAsyncThreadPool"); - - for (auto& thread : threadPool) { - if (thread.joinable()) { - thread.join(); - } - } -} - -std::shared_ptr MpiAsyncThreadPool::getMpiReqQueue() -{ - return this->localReqQueue; -} - -void MpiAsyncThreadPool::entrypoint(int i) -{ - faabric::scheduler::ReqQueueType req; - - while (!this->isShutdown) { - req = getMpiReqQueue()->dequeue(); - - int id = std::get<0>(req); - std::function func = std::get<1>(req); - std::promise promise = std::move(std::get<2>(req)); - - // Detect shutdown condition - if (id == QUEUE_SHUTDOWN) { - // The shutdown tuple includes a TLS cleanup function that we run - // _once per thread_ and exit - func(); - if (!this->isShutdown) { - this->isShutdown = true; - } - SPDLOG_TRACE("Mpi thread {}/{} shut down", i + 1, size); - break; - } - - // Do the job without holding any locks - func(); - - // Notify we are done via the future - promise.set_value(); - } -} -} diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 077de1deb..c995d4c8e 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -5,10 +5,21 @@ #include #include -static thread_local std::unordered_map> futureMap; +#include + +#define MPI_IS_ISEND_REQUEST -1 + +/* Each MPI rank runs in a separate thread, however they interact with faabric + * as a library. Thus, we use thread_local storage to guarantee that each rank + * sees its own version of these data structures. + */ static thread_local std::vector< std::unique_ptr> mpiMessageEndpoints; +static thread_local std::vector< + std::shared_ptr> + unackedMessageBuffers; +static thread_local std::map> reqIdToRanks; namespace faabric::scheduler { MpiWorld::MpiWorld() @@ -97,25 +108,29 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage( return mpiMessageEndpoints[index]->recvMpiMessage(); } -int MpiWorld::getMpiThreadPoolSize() +// We want to lazily initialise this data structure because, given its thread +// local nature, we expect it to be quite sparse (i.e. filled with nullptr). 
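+// Note that the buffers live in a flat vector indexed by sender-receiver
+// pair (see getIndexForRanks below), so each rank's thread holds at most
+// size * size entries, most of which stay null.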
+std::shared_ptr +MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank) { - int usableCores = faabric::util::getUsableCores(); - int worldSize = size; + // Lazy initialise all empty slots + if (unackedMessageBuffers.size() == 0) { + for (int i = 0; i < size * size; i++) { + unackedMessageBuffers.emplace_back(nullptr); + } + } - if ((worldSize > usableCores) && (worldSize % usableCores != 0)) { - SPDLOG_WARN("Over-provisioning threads in the MPI thread pool."); - SPDLOG_WARN("To avoid this, set an MPI world size multiple of the " - "number of cores per machine."); + // Get the index for the rank-host pair + int index = getIndexForRanks(sendRank, recvRank); + assert(index >= 0 && index < size * size); + + if (unackedMessageBuffers[index] == nullptr) { + unackedMessageBuffers.emplace( + unackedMessageBuffers.begin() + index, + std::make_shared()); } - // Note - adding one to the worldSize to prevent deadlocking in certain - // corner-cases. - // For instance, if issuing `worldSize` non-blocking recvs, followed by - // `worldSize` non-blocking sends, and nothing else, the application will - // deadlock as all worker threads will be blocking on `recv` calls. This - // scenario is remote, but feasible. We _assume_ that following the same - // pattern but doing `worldSize + 1` calls is deliberately malicious, and - // we can confidently fail and deadlock. - return std::min(worldSize + 1, usableCores); + + return unackedMessageBuffers[index]; } void MpiWorld::create(const faabric::Message& call, int newId, int newSize) @@ -125,8 +140,6 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) function = call.function(); size = newSize; - threadPool = std::make_shared( - getMpiThreadPoolSize()); auto& sch = faabric::scheduler::getScheduler(); @@ -170,13 +183,37 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) void MpiWorld::destroy() { - // Destroy once per host - if (!isDestroyed.test_and_set()) { - shutdownThreadPool(); + // Destroy once per thread the rank-specific data structures + // Remote message endpoints + if (mpiMessageEndpoints.size() > 0) { + for (auto& e : mpiMessageEndpoints) { + if (e != nullptr) { + e->close(); + } + } + mpiMessageEndpoints.clear(); + } + + // Unacked message buffers + if (unackedMessageBuffers.size() > 0) { + for (auto& umb : unackedMessageBuffers) { + if (umb != nullptr) { + assert(umb->ids.empty() && umb->msgs.empty() && + umb->args.empty()); + umb->ids.clear(); + umb->msgs.clear(); + umb->args.clear(); + } + } + unackedMessageBuffers.clear(); + } - // Note - we are deliberately not deleting the KV in the global state - // TODO - find a way to do this only from the master client + // Request to rank map + assert(reqIdToRanks.empty()); + reqIdToRanks.clear(); + // Destroy once per host the shared resources + if (!isDestroyed.test_and_set()) { // Wait (forever) until all ranks are done consuming their queues to // clear them. // Note - this means that an application with outstanding messages, i.e. @@ -190,40 +227,6 @@ void MpiWorld::destroy() } } -void MpiWorld::shutdownThreadPool() -{ - // When shutting down the thread pool, we also make sure we clean all thread - // local state by sending a clear message to the queue. 
Currently, we only - // need to close the function call clients - for (int i = 0; i < threadPool->size; i++) { - std::promise p; - threadPool->getMpiReqQueue()->enqueue( - std::make_tuple(QUEUE_SHUTDOWN, - std::bind(&MpiWorld::closeMpiMessageEndpoints, this), - std::move(p))); - } - - threadPool->shutdown(); - - // Lastly clean the main thread as well - closeMpiMessageEndpoints(); -} - -// TODO - remove -// Clear thread local state -void MpiWorld::closeMpiMessageEndpoints() -{ - if (mpiMessageEndpoints.size() > 0) { - // Close all open sockets - for (auto& e : mpiMessageEndpoints) { - if (e != nullptr) { - e->close(); - } - } - mpiMessageEndpoints.clear(); - } -} - void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) { id = msg.mpiworldid(); @@ -231,9 +234,6 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) function = msg.function(); size = msg.mpiworldsize(); - threadPool = std::make_shared( - getMpiThreadPoolSize()); - // Sometimes for testing purposes we may want to initialise a world in the // _same_ host we have created one (note that this would never happen in // reality). If so, we skip initialising resources already initialised @@ -394,6 +394,10 @@ void MpiWorld::shiftCartesianCoords(int rank, getRankFromCoords(source, dispCoordsBwd.data()); } +// Sending is already asynchronous in both transport layers we use: in-memory +// queues for local messages, and ZeroMQ sockets for remote messages. Thus, +// we can just send normally and return a requestId. Upon await, we'll return +// immediately. int MpiWorld::isend(int sendRank, int recvRank, const uint8_t* buffer, @@ -402,23 +406,14 @@ int MpiWorld::isend(int sendRank, faabric::MPIMessage::MPIMessageType messageType) { int requestId = (int)faabric::util::generateGid(); + auto it = reqIdToRanks.try_emplace( + requestId, MPI_IS_ISEND_REQUEST, MPI_IS_ISEND_REQUEST); + if (!it.second) { + SPDLOG_ERROR("Request ID {} is already present in map", requestId); + throw std::runtime_error("Request ID already in map"); + } - std::promise resultPromise; - std::future resultFuture = resultPromise.get_future(); - threadPool->getMpiReqQueue()->enqueue( - std::make_tuple(requestId, - std::bind(&MpiWorld::send, - this, - sendRank, - recvRank, - buffer, - dataType, - count, - messageType), - std::move(resultPromise))); - - // Place the promise in a map to wait for it later - futureMap.emplace(std::make_pair(requestId, std::move(resultFuture))); + send(sendRank, recvRank, buffer, dataType, count, messageType); return requestId; } @@ -431,24 +426,20 @@ int MpiWorld::irecv(int sendRank, faabric::MPIMessage::MPIMessageType messageType) { int requestId = (int)faabric::util::generateGid(); + auto it = reqIdToRanks.try_emplace(requestId, sendRank, recvRank); + if (!it.second) { + SPDLOG_ERROR("Request ID {} is already present in map", requestId); + throw std::runtime_error("Request ID already in map"); + } - std::promise resultPromise; - std::future resultFuture = resultPromise.get_future(); - threadPool->getMpiReqQueue()->enqueue( - std::make_tuple(requestId, - std::bind(&MpiWorld::recv, - this, - sendRank, - recvRank, - buffer, - dataType, - count, - nullptr, - messageType), - std::move(resultPromise))); - - // Place the promise in a map to wait for it later - futureMap.emplace(std::make_pair(requestId, std::move(resultFuture))); + faabric::scheduler::UnackedMessageBuffer::Arguments args = { + sendRank, recvRank, buffer, dataType, count, messageType + }; + + auto umb = getUnackedMessageBuffer(sendRank, 
recvRank); + umb->ids.push_back(requestId); + umb->args.push_back(args); + umb->msgs.push_back(nullptr); return requestId; } @@ -529,17 +520,21 @@ void MpiWorld::recv(int sendRank, const std::string otherHost = getHostForRank(sendRank); bool isLocal = otherHost == thisHost; - // Recv message - std::shared_ptr m; - if (isLocal) { - SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank); - m = getLocalQueue(sendRank, recvRank)->dequeue(); - } else { - SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank); - m = recvRemoteMpiMessage(sendRank, recvRank); - } - assert(m != nullptr); + // Recv message from underlying transport + std::shared_ptr m = + recvBatchReturnLast(sendRank, recvRank); + + // Do the processing + doRecv(m, buffer, dataType, count, status, messageType); +} +void MpiWorld::doRecv(std::shared_ptr m, + uint8_t* buffer, + faabric_datatype_t* dataType, + int count, + MPI_Status* status, + faabric::MPIMessage::MPIMessageType messageType) +{ // Assert message integrity // Note - this checks won't happen in Release builds assert(m->messagetype() == messageType); @@ -806,17 +801,65 @@ void MpiWorld::awaitAsyncRequest(int requestId) { SPDLOG_TRACE("MPI - await {}", requestId); - auto it = futureMap.find(requestId); - if (it == futureMap.end()) { - throw std::runtime_error( - fmt::format("Error: waiting for unrecognized request {}", requestId)); + // Get the corresponding send and recv ranks + auto it = reqIdToRanks.find(requestId); + if (it == reqIdToRanks.end()) { + SPDLOG_ERROR("Asynchronous request id not recognized: {}", requestId); + throw std::runtime_error("Unrecognized async request id"); + } + int sendRank = it->second.first; + int recvRank = it->second.second; + reqIdToRanks.erase(it); + + // If awaiting an isend request, return immediately as our transport layer + // is asynchronous, thus request is already handled + if (sendRank == MPI_IS_ISEND_REQUEST && recvRank == MPI_IS_ISEND_REQUEST) { + return; } - // This call blocks until requestId has finished. - it->second.wait(); - futureMap.erase(it); + std::shared_ptr umb = + getUnackedMessageBuffer(sendRank, recvRank); + + // The request id must be in the UMB, as an irecv must happen before an + // await. It is very likely that if the assert were to fail, the previous + // check would have also failed, thus why we assert here. 
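+    // The ids, args and msgs lists are kept in lockstep (same length, same
+    // order), so the position of the request id also gives the position of
+    // its arguments and of its (possibly null) message.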
+ auto idIndex = std::find(umb->ids.begin(), umb->ids.end(), requestId); + assert(idIndex != umb->ids.end()); + auto idDistance = std::distance(umb->ids.begin(), idIndex); - SPDLOG_DEBUG("Finished awaitAsyncRequest on {}", requestId); + // Get the corresponding message + assert(idDistance < umb->msgs.size()); + auto msgIndex = umb->msgs.begin(); + std::advance(msgIndex, idDistance); + + // Get the corresponding request arguments + assert(idDistance < umb->args.size()); + auto argsIndex = umb->args.begin(); + std::advance(argsIndex, idDistance); + + std::shared_ptr m; + if (*msgIndex != nullptr) { + // This id has already been acknowledged by a recv call, so do the recv + m = *msgIndex; + } else { + // We need to acknowledge all messages not acknowledged from the + // begining until us + auto firstNullMsg = std::find(umb->msgs.begin(), msgIndex, nullptr); + m = recvBatchReturnLast( + sendRank, recvRank, std::distance(firstNullMsg, msgIndex) + 1); + } + + doRecv(m, + argsIndex->buffer, + argsIndex->dataType, + argsIndex->count, + MPI_STATUS_IGNORE, + argsIndex->messageType); + + // Remove the acknowledged indexes from the UMB + umb->ids.erase(idIndex); + umb->msgs.erase(msgIndex); + umb->args.erase(argsIndex); } void MpiWorld::reduce(int sendRank, @@ -1180,6 +1223,70 @@ void MpiWorld::initLocalQueues() } } +std::shared_ptr +MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) +{ + std::shared_ptr umb = + getUnackedMessageBuffer(sendRank, recvRank); + + // When calling from recv, we set the batch size to zero and work + // out the total here + auto firstNullMsg = std::find(umb->msgs.begin(), umb->msgs.end(), nullptr); + if (batchSize == 0) { + batchSize = std::distance(firstNullMsg, umb->msgs.end()) + 1; + } + + // Work out whether the message is sent locally or from another host + assert(thisHost == getHostForRank(recvRank)); + const std::string otherHost = getHostForRank(sendRank); + bool isLocal = otherHost == thisHost; + + // Recv message: first we receive all messages for which there is an id + // in the unacknowleged buffer but no msg. Note that these messages + // (batchSize - 1) were `irecv`-ed before ours. 
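+    // For example, if two irecv calls are still unacknowledged when recv is
+    // called, batchSize is 3: the first two messages are dequeued into their
+    // UMB slots and the third is returned to the caller.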
+ std::shared_ptr m; + auto it = firstNullMsg; + if (isLocal) { + // First receive messages that happened before us + for (int i = 0; i < batchSize - 1; i++) { + SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank); + auto _m = getLocalQueue(sendRank, recvRank)->dequeue(); + + assert(_m != nullptr); + assert(*it == nullptr); + + // Put the unacked message in the UMB + *it = _m; + it++; + } + + // Finally receive the message corresponding to us + SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank); + m = getLocalQueue(sendRank, recvRank)->dequeue(); + } else { + // First receive messages that happened before us + for (int i = 0; i < batchSize - 1; i++) { + SPDLOG_TRACE( + "MPI - pending remote recv {} -> {}", sendRank, recvRank); + auto _m = recvRemoteMpiMessage(sendRank, recvRank); + + assert(_m != nullptr); + assert(*it == nullptr); + + // Put the unacked message in the UMB + *it = _m; + it++; + } + + // Finally receive the message corresponding to us + SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank); + m = recvRemoteMpiMessage(sendRank, recvRank); + } + assert(m != nullptr); + + return m; +} + int MpiWorld::getIndexForRanks(int sendRank, int recvRank) { int index = sendRank * size + recvRank; diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index f275bc672..6584e3c25 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -387,23 +387,37 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, remoteWorld.initialiseFromMsg(msg); // Send message twice - remoteWorld.send( - sendRank, recvRank, BYTES(messageData.data()), MPI_INT, messageData.size()); - remoteWorld.send( - sendRank, recvRank, BYTES(messageData.data()), MPI_INT, messageData.size()); + remoteWorld.send(sendRank, + recvRank, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + remoteWorld.send(sendRank, + recvRank, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); usleep(1000 * 500); remoteWorld.destroy(); }); - // Receive one message asynchronously + // Receive one message asynchronously std::vector asyncMessage(messageData.size(), 0); - int recvId = - localWorld.irecv(sendRank, recvRank, BYTES(asyncMessage.data()), MPI_INT, asyncMessage.size()); + int recvId = localWorld.irecv(sendRank, + recvRank, + BYTES(asyncMessage.data()), + MPI_INT, + asyncMessage.size()); // Receive one message synchronously std::vector syncMessage(messageData.size(), 0); - localWorld.recv(sendRank, recvRank, BYTES(syncMessage.data()), MPI_INT, syncMessage.size(), MPI_STATUS_IGNORE); + localWorld.recv(sendRank, + recvRank, + BYTES(syncMessage.data()), + MPI_INT, + syncMessage.size(), + MPI_STATUS_IGNORE); // Wait for the async message localWorld.awaitAsyncRequest(recvId); From 0110ff4389a3c2be583fc5b39be7e4b66f7212ac Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Mon, 14 Jun 2021 08:28:39 +0000 Subject: [PATCH 3/8] adding more tests --- .../test/scheduler/test_remote_mpi_worlds.cpp | 145 +++++++++++++++++- 1 file changed, 137 insertions(+), 8 deletions(-) diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 6584e3c25..55ec8d298 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -75,7 +75,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") // Send a message that should get sent to this host remoteWorld.send( rankB, rankA, 
BYTES(messageData.data()), MPI_INT, messageData.size()); - usleep(1000 * 500); remoteWorld.destroy(); }); @@ -115,21 +114,18 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, faabric::util::setMockMode(false); std::thread senderThread([this, rankA, rankB, numMessages] { - std::vector messageData = { 0, 1, 2 }; - remoteWorld.initialiseFromMsg(msg); for (int i = 0; i < numMessages; i++) { - remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, sizeof(int)); + remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - usleep(1000 * 500); remoteWorld.destroy(); }); int recv; for (int i = 0; i < numMessages; i++) { localWorld.recv( - rankB, rankA, BYTES(&recv), MPI_INT, sizeof(int), MPI_STATUS_IGNORE); + rankB, rankA, BYTES(&recv), MPI_INT, 1, MPI_STATUS_IGNORE); // Check in-order delivery if (i % (numMessages / 10) == 0) { @@ -377,7 +373,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, int recvRank = 0; std::vector messageData = { 0, 1, 2 }; - // Initi world + // Init world MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); @@ -398,7 +394,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); - usleep(1000 * 500); remoteWorld.destroy(); }); @@ -430,4 +425,138 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, senderThread.join(); localWorld.destroy(); } + +TEST_CASE_METHOD(RemoteMpiTestFixture, + "Test receiving remote async requests out of order", + "[mpi]") +{ + // Allocate two ranks in total, one rank per host + this->setWorldsSizes(2, 1, 1); + int sendRank = 1; + int recvRank = 0; + + // Init world + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + + std::thread senderThread([this, sendRank, recvRank] { + remoteWorld.initialiseFromMsg(msg); + + // Send different messages + for (int i = 0; i < 3; i++) { + remoteWorld.send(sendRank, + recvRank, + BYTES(&i), + MPI_INT, + 1); + } + + remoteWorld.destroy(); + }); + + // Receive two messages asynchronously + int recv1, recv2, recv3; + int recvId1 = localWorld.irecv(sendRank, + recvRank, + BYTES(&recv1), + MPI_INT, + 1); + + int recvId2 = localWorld.irecv(sendRank, + recvRank, + BYTES(&recv2), + MPI_INT, + 1); + + // Receive one message synchronously + localWorld.recv(sendRank, + recvRank, + BYTES(&recv3), + MPI_INT, + 1, + MPI_STATUS_IGNORE); + + SECTION("Wait out of order") + { + localWorld.awaitAsyncRequest(recvId2); + localWorld.awaitAsyncRequest(recvId1); + } + + SECTION("Wait in order") + { + localWorld.awaitAsyncRequest(recvId1); + localWorld.awaitAsyncRequest(recvId2); + } + + // Checks + REQUIRE(recv1 == 0); + REQUIRE(recv2 == 1); + REQUIRE(recv3 == 2); + + // Destroy world + senderThread.join(); + localWorld.destroy(); +} + +TEST_CASE_METHOD(RemoteMpiTestFixture, + "Test ring sendrecv across hosts", + "[mpi]") +{ + // Allocate two ranks in total, one rank per host + this->setWorldsSizes(3, 1, 2); + int worldSize = 3; + std::vector localRanks = {0}; + + // Init world + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + + std::thread senderThread([this, worldSize] { + std::vector remoteRanks = {1, 2}; + remoteWorld.initialiseFromMsg(msg); + + // Send different messages + for (auto& rank : remoteRanks) { + int left = rank > 0 ? 
rank - 1 : worldSize - 1; + int right = (rank + 1) % worldSize; + int recvData = -1; + + remoteWorld.sendRecv(BYTES(&rank), + 1, + MPI_INT, + right, + BYTES(&recvData), + 1, + MPI_INT, + left, + rank, + MPI_STATUS_IGNORE); + } + + remoteWorld.destroy(); + }); + + for (auto& rank : localRanks) { + int left = rank > 0 ? rank - 1 : worldSize - 1; + int right = (rank + 1) % worldSize; + int recvData = -1; + + localWorld.sendRecv(BYTES(&rank), + 1, + MPI_INT, + right, + BYTES(&recvData), + 1, + MPI_INT, + left, + rank, + MPI_STATUS_IGNORE); + + REQUIRE(recvData == left); + } + + // Destroy world + senderThread.join(); + localWorld.destroy(); +} } From a63440581938f62d6fc9a97a82705970a3295350 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Tue, 15 Jun 2021 07:50:24 +0000 Subject: [PATCH 4/8] introducing the mpi message buffer and encapsulating most logic there --- include/faabric/scheduler/MpiMessageBuffer.h | 55 +++++++ include/faabric/scheduler/MpiWorld.h | 31 +--- src/scheduler/CMakeLists.txt | 1 + src/scheduler/MpiMessageBuffer.cpp | 73 ++++++++++ src/scheduler/MpiWorld.cpp | 135 ++++++++---------- .../test/scheduler/test_remote_mpi_worlds.cpp | 34 ++--- 6 files changed, 199 insertions(+), 130 deletions(-) create mode 100644 include/faabric/scheduler/MpiMessageBuffer.h create mode 100644 src/scheduler/MpiMessageBuffer.cpp diff --git a/include/faabric/scheduler/MpiMessageBuffer.h b/include/faabric/scheduler/MpiMessageBuffer.h new file mode 100644 index 000000000..88f4ad7f6 --- /dev/null +++ b/include/faabric/scheduler/MpiMessageBuffer.h @@ -0,0 +1,55 @@ +#include +#include + +#include +#include + +namespace faabric::scheduler { +/* The MPI message buffer (MMB) keeps track of the asyncrhonous + * messages that we must have received (i.e. through an irecv call) but we + * still have not waited on (acknowledged). Messages are acknowledged either + * through a call to recv or a call to await. A call to recv will + * acknowledge (i.e. synchronously read from transport buffers) as many + * unacknowleged messages there are, plus one. + */ +class MpiMessageBuffer +{ + public: + struct Arguments + { + int requestId; + std::shared_ptr msg; + int sendRank; + int recvRank; + uint8_t* buffer; + faabric_datatype_t* dataType; + int count; + faabric::MPIMessage::MPIMessageType messageType; + }; + + void addMessage(Arguments arg); + + void deleteMessage(const std::list::iterator& argIt); + + bool isEmpty(); + + int size(); + + std::list::iterator getRequestArguments(int requestId); + + std::list::iterator getFirstNullMsgUntil( + const std::list::iterator& argIt); + + std::list::iterator getFirstNullMsg(); + + int getTotalUnackedMessagesUntil( + const std::list::iterator& argIt); + + int getTotalUnackedMessages(); + + private: + // We keep track of the request id and its arguments. Note that the message + // is part of the arguments and may be null if unacknowleged. + std::list args; +}; +} diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index ee716c716..f3b0d00c1 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -4,44 +4,17 @@ #include #include +#include #include #include #include #include -#include namespace faabric::scheduler { typedef faabric::util::Queue> InMemoryMpiQueue; -/* The untracked message buffer (UMB) keeps track of the asyncrhonous - * messages that we must have received (i.e. through an irecv call) but we - * still have not waited on (acknowledged). 
Messages are acknowledged either - * through a call to recv or a call to await. A call to recv will - * acknowledge (i.e. synchronously read from transport buffers) as many - * unacknowleged messages there are, plus one. - */ -struct UnackedMessageBuffer -{ - struct Arguments - { - int sendRank; - int recvRank; - uint8_t* buffer; - faabric_datatype_t* dataType; - int count; - faabric::MPIMessage::MPIMessageType messageType; - }; - - // We keep track of: the request id, its arguments, and the message it - // acknowledges (may be null if unacknowleged). All three lists should - // always be the same size. - std::list ids; - std::list args; - std::list> msgs; -}; - class MpiWorld { public: @@ -242,7 +215,7 @@ class MpiWorld void closeMpiMessageEndpoints(); // Support for asyncrhonous communications - std::shared_ptr + std::shared_ptr getUnackedMessageBuffer(int sendRank, int recvRank); std::shared_ptr recvBatchReturnLast(int sendRank, int recvRank, diff --git a/src/scheduler/CMakeLists.txt b/src/scheduler/CMakeLists.txt index a1abd4374..4e8c271ad 100644 --- a/src/scheduler/CMakeLists.txt +++ b/src/scheduler/CMakeLists.txt @@ -10,6 +10,7 @@ set(LIB_FILES SnapshotServer.cpp SnapshotClient.cpp MpiContext.cpp + MpiMessageBuffer.cpp MpiWorldRegistry.cpp MpiWorld.cpp ${HEADERS} diff --git a/src/scheduler/MpiMessageBuffer.cpp b/src/scheduler/MpiMessageBuffer.cpp new file mode 100644 index 000000000..3ddd2395a --- /dev/null +++ b/src/scheduler/MpiMessageBuffer.cpp @@ -0,0 +1,73 @@ +#include +#include + +namespace faabric::scheduler { +typedef std::list::iterator ArgListIterator; +void MpiMessageBuffer::addMessage(Arguments arg) +{ + // Ensure we are enqueueing a null message (i.e. unacknowleged) + assert(arg.msg == nullptr); + + args.push_back(arg); +} + +void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt) +{ + args.erase(argIt); + return; +} + +bool MpiMessageBuffer::isEmpty() +{ + return args.empty(); +} + +int MpiMessageBuffer::size() +{ + return args.size(); +} + +ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId) +{ + // The request id must be in the UMB, as an irecv must happen before an + // await + ArgListIterator argIt = + std::find_if(args.begin(), args.end(), [requestId](Arguments args) { + return args.requestId == requestId; + }); + + // If it's not there, error out + if (argIt == args.end()) { + SPDLOG_ERROR("Asynchronous request id not in UMB: {}", requestId); + throw std::runtime_error("Async request not in buffer"); + } + + return argIt; +} + +ArgListIterator MpiMessageBuffer::getFirstNullMsgUntil( + const ArgListIterator& argItEnd) +{ + return std::find_if(args.begin(), argItEnd, [](Arguments args) { + return args.msg == nullptr; + }); +} + +ArgListIterator MpiMessageBuffer::getFirstNullMsg() +{ + return getFirstNullMsgUntil(args.end()); +} + +int MpiMessageBuffer::getTotalUnackedMessagesUntil( + const ArgListIterator& argItEnd) +{ + ArgListIterator firstNull = getFirstNullMsgUntil(argItEnd); + return std::distance(firstNull, argItEnd); +} + +int MpiMessageBuffer::getTotalUnackedMessages() +{ + ArgListIterator firstNull = getFirstNullMsg(); + return std::distance(firstNull, args.end()); +} +} diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index c995d4c8e..3ade3ff44 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -5,10 +5,6 @@ #include #include -#include - -#define MPI_IS_ISEND_REQUEST -1 - /* Each MPI rank runs in a separate thread, however they interact with faabric * as a library. 
Thus, we use thread_local storage to guarantee that each rank * sees its own version of these data structures. @@ -17,8 +13,9 @@ static thread_local std::vector< std::unique_ptr> mpiMessageEndpoints; static thread_local std::vector< - std::shared_ptr> + std::shared_ptr> unackedMessageBuffers; +static thread_local std::set iSendRequests; static thread_local std::map> reqIdToRanks; namespace faabric::scheduler { @@ -108,12 +105,12 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage( return mpiMessageEndpoints[index]->recvMpiMessage(); } -// We want to lazily initialise this data structure because, given its thread -// local nature, we expect it to be quite sparse (i.e. filled with nullptr). -std::shared_ptr +std::shared_ptr MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank) { - // Lazy initialise all empty slots + // We want to lazily initialise this data structure because, given its + // thread local nature, we expect it to be quite sparse (i.e. filled with + // nullptr). if (unackedMessageBuffers.size() == 0) { for (int i = 0; i < size * size; i++) { unackedMessageBuffers.emplace_back(nullptr); @@ -127,7 +124,7 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank) if (unackedMessageBuffers[index] == nullptr) { unackedMessageBuffers.emplace( unackedMessageBuffers.begin() + index, - std::make_shared()); + std::make_shared()); } return unackedMessageBuffers[index]; @@ -185,7 +182,7 @@ void MpiWorld::destroy() { // Destroy once per thread the rank-specific data structures // Remote message endpoints - if (mpiMessageEndpoints.size() > 0) { + if (!mpiMessageEndpoints.empty()) { for (auto& e : mpiMessageEndpoints) { if (e != nullptr) { e->close(); @@ -195,22 +192,28 @@ void MpiWorld::destroy() } // Unacked message buffers - if (unackedMessageBuffers.size() > 0) { + if (!unackedMessageBuffers.empty()) { for (auto& umb : unackedMessageBuffers) { if (umb != nullptr) { - assert(umb->ids.empty() && umb->msgs.empty() && - umb->args.empty()); - umb->ids.clear(); - umb->msgs.clear(); - umb->args.clear(); + if (!umb->isEmpty()) { + SPDLOG_ERROR("Destroying the MPI world with outstanding {}" + " messages in the message buffer", + umb->size()); + throw std::runtime_error( + "Destroying world with a non-empty MPI message buffer"); + } } } unackedMessageBuffers.clear(); } - // Request to rank map - assert(reqIdToRanks.empty()); - reqIdToRanks.clear(); + // Request to rank map should be empty + if (!reqIdToRanks.empty()) { + SPDLOG_ERROR( + "Destroying the MPI world with {} outstanding async requests", + reqIdToRanks.size()); + throw std::runtime_error("Destroying world with outstanding requests"); + } // Destroy once per host the shared resources if (!isDestroyed.test_and_set()) { @@ -406,12 +409,7 @@ int MpiWorld::isend(int sendRank, faabric::MPIMessage::MPIMessageType messageType) { int requestId = (int)faabric::util::generateGid(); - auto it = reqIdToRanks.try_emplace( - requestId, MPI_IS_ISEND_REQUEST, MPI_IS_ISEND_REQUEST); - if (!it.second) { - SPDLOG_ERROR("Request ID {} is already present in map", requestId); - throw std::runtime_error("Request ID already in map"); - } + iSendRequests.insert(requestId); send(sendRank, recvRank, buffer, dataType, count, messageType); @@ -426,20 +424,17 @@ int MpiWorld::irecv(int sendRank, faabric::MPIMessage::MPIMessageType messageType) { int requestId = (int)faabric::util::generateGid(); - auto it = reqIdToRanks.try_emplace(requestId, sendRank, recvRank); - if (!it.second) { - SPDLOG_ERROR("Request ID {} is already present in map", requestId); - 
throw std::runtime_error("Request ID already in map"); - } + reqIdToRanks.try_emplace(requestId, sendRank, recvRank); - faabric::scheduler::UnackedMessageBuffer::Arguments args = { - sendRank, recvRank, buffer, dataType, count, messageType + // Enqueue a request with a null-pointing message (i.e. unacknowleged) and + // the generated request id + faabric::scheduler::MpiMessageBuffer::Arguments args = { + requestId, nullptr, sendRank, recvRank, + buffer, dataType, count, messageType }; auto umb = getUnackedMessageBuffer(sendRank, recvRank); - umb->ids.push_back(requestId); - umb->args.push_back(args); - umb->msgs.push_back(nullptr); + umb->addMessage(args); return requestId; } @@ -801,8 +796,17 @@ void MpiWorld::awaitAsyncRequest(int requestId) { SPDLOG_TRACE("MPI - await {}", requestId); + auto iSendIt = iSendRequests.find(requestId); + if (iSendIt != iSendRequests.end()) { + iSendRequests.erase(iSendIt); + return; + } + // Get the corresponding send and recv ranks auto it = reqIdToRanks.find(requestId); + // If the request id is not in the map, the application has issued an + // await() without a previous `isend`, `irecv`, or the actual request id + // has been corrupted. In any case, we error out. if (it == reqIdToRanks.end()) { SPDLOG_ERROR("Asynchronous request id not recognized: {}", requestId); throw std::runtime_error("Unrecognized async request id"); @@ -811,42 +815,21 @@ void MpiWorld::awaitAsyncRequest(int requestId) int recvRank = it->second.second; reqIdToRanks.erase(it); - // If awaiting an isend request, return immediately as our transport layer - // is asynchronous, thus request is already handled - if (sendRank == MPI_IS_ISEND_REQUEST && recvRank == MPI_IS_ISEND_REQUEST) { - return; - } - - std::shared_ptr umb = + std::shared_ptr umb = getUnackedMessageBuffer(sendRank, recvRank); - // The request id must be in the UMB, as an irecv must happen before an - // await. It is very likely that if the assert were to fail, the previous - // check would have also failed, thus why we assert here. 
- auto idIndex = std::find(umb->ids.begin(), umb->ids.end(), requestId); - assert(idIndex != umb->ids.end()); - auto idDistance = std::distance(umb->ids.begin(), idIndex); - - // Get the corresponding message - assert(idDistance < umb->msgs.size()); - auto msgIndex = umb->msgs.begin(); - std::advance(msgIndex, idDistance); - - // Get the corresponding request arguments - assert(idDistance < umb->args.size()); - auto argsIndex = umb->args.begin(); - std::advance(argsIndex, idDistance); + std::list::iterator argsIndex = + umb->getRequestArguments(requestId); std::shared_ptr m; - if (*msgIndex != nullptr) { + if (argsIndex->msg != nullptr) { // This id has already been acknowledged by a recv call, so do the recv - m = *msgIndex; + m = argsIndex->msg; } else { // We need to acknowledge all messages not acknowledged from the // begining until us - auto firstNullMsg = std::find(umb->msgs.begin(), msgIndex, nullptr); m = recvBatchReturnLast( - sendRank, recvRank, std::distance(firstNullMsg, msgIndex) + 1); + sendRank, recvRank, umb->getTotalUnackedMessagesUntil(argsIndex) + 1); } doRecv(m, @@ -857,9 +840,7 @@ void MpiWorld::awaitAsyncRequest(int requestId) argsIndex->messageType); // Remove the acknowledged indexes from the UMB - umb->ids.erase(idIndex); - umb->msgs.erase(msgIndex); - umb->args.erase(argsIndex); + umb->deleteMessage(argsIndex); } void MpiWorld::reduce(int sendRank, @@ -1226,14 +1207,14 @@ void MpiWorld::initLocalQueues() std::shared_ptr MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) { - std::shared_ptr umb = + std::shared_ptr umb = getUnackedMessageBuffer(sendRank, recvRank); // When calling from recv, we set the batch size to zero and work - // out the total here - auto firstNullMsg = std::find(umb->msgs.begin(), umb->msgs.end(), nullptr); + // out the total here. We want to acknowledge _all_ unacknowleged messages + // _and then_ receive ours (which is not in the MMB). if (batchSize == 0) { - batchSize = std::distance(firstNullMsg, umb->msgs.end()) + 1; + batchSize = umb->getTotalUnackedMessages() + 1; } // Work out whether the message is sent locally or from another host @@ -1245,7 +1226,7 @@ MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) // in the unacknowleged buffer but no msg. Note that these messages // (batchSize - 1) were `irecv`-ed before ours. 
std::shared_ptr m; - auto it = firstNullMsg; + auto argsIt = umb->getFirstNullMsg(); if (isLocal) { // First receive messages that happened before us for (int i = 0; i < batchSize - 1; i++) { @@ -1253,11 +1234,11 @@ MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) auto _m = getLocalQueue(sendRank, recvRank)->dequeue(); assert(_m != nullptr); - assert(*it == nullptr); + assert(argsIt->msg == nullptr); // Put the unacked message in the UMB - *it = _m; - it++; + argsIt->msg = _m; + argsIt++; } // Finally receive the message corresponding to us @@ -1271,11 +1252,11 @@ MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) auto _m = recvRemoteMpiMessage(sendRank, recvRank); assert(_m != nullptr); - assert(*it == nullptr); + assert(argsIt->msg == nullptr); // Put the unacked message in the UMB - *it = _m; - it++; + argsIt->msg = _m; + argsIt++; } // Finally receive the message corresponding to us diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 55ec8d298..6a70a3be6 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -444,11 +444,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, // Send different messages for (int i = 0; i < 3; i++) { - remoteWorld.send(sendRank, - recvRank, - BYTES(&i), - MPI_INT, - 1); + remoteWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } remoteWorld.destroy(); @@ -456,25 +452,15 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, // Receive two messages asynchronously int recv1, recv2, recv3; - int recvId1 = localWorld.irecv(sendRank, - recvRank, - BYTES(&recv1), - MPI_INT, - 1); - - int recvId2 = localWorld.irecv(sendRank, - recvRank, - BYTES(&recv2), - MPI_INT, - 1); + int recvId1 = + localWorld.irecv(sendRank, recvRank, BYTES(&recv1), MPI_INT, 1); + + int recvId2 = + localWorld.irecv(sendRank, recvRank, BYTES(&recv2), MPI_INT, 1); // Receive one message synchronously - localWorld.recv(sendRank, - recvRank, - BYTES(&recv3), - MPI_INT, - 1, - MPI_STATUS_IGNORE); + localWorld.recv( + sendRank, recvRank, BYTES(&recv3), MPI_INT, 1, MPI_STATUS_IGNORE); SECTION("Wait out of order") { @@ -505,14 +491,14 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, // Allocate two ranks in total, one rank per host this->setWorldsSizes(3, 1, 2); int worldSize = 3; - std::vector localRanks = {0}; + std::vector localRanks = { 0 }; // Init world MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); std::thread senderThread([this, worldSize] { - std::vector remoteRanks = {1, 2}; + std::vector remoteRanks = { 1, 2 }; remoteWorld.initialiseFromMsg(msg); // Send different messages From 3055782874eb9aba3f84c7a94a0c7a00d7d46a0d Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Tue, 15 Jun 2021 10:00:15 +0000 Subject: [PATCH 5/8] adding tests for the mpi message buffer + formatting --- include/faabric/scheduler/MpiMessageBuffer.h | 34 +++-- src/scheduler/MpiMessageBuffer.cpp | 24 ++-- src/scheduler/MpiWorld.cpp | 22 +--- .../scheduler/test_mpi_message_buffer.cpp | 119 ++++++++++++++++++ .../test/scheduler/test_remote_mpi_worlds.cpp | 31 ++--- 5 files changed, 171 insertions(+), 59 deletions(-) create mode 100644 tests/test/scheduler/test_mpi_message_buffer.cpp diff --git a/include/faabric/scheduler/MpiMessageBuffer.h b/include/faabric/scheduler/MpiMessageBuffer.h index 88f4ad7f6..9715ee346 100644 --- a/include/faabric/scheduler/MpiMessageBuffer.h +++ 
b/include/faabric/scheduler/MpiMessageBuffer.h @@ -10,11 +10,17 @@ namespace faabric::scheduler { * still have not waited on (acknowledged). Messages are acknowledged either * through a call to recv or a call to await. A call to recv will * acknowledge (i.e. synchronously read from transport buffers) as many - * unacknowleged messages there are, plus one. + * unacknowleged messages there are. A call to await with a request + * id as a parameter will acknowledge as many unacknowleged messages there are + * until said request id. */ class MpiMessageBuffer { public: + /* This structure holds the metadata for each Mpi message we keep in the + * buffer. Note that the message field will point to null if unacknowleged + * or to a valid message otherwise. + */ struct Arguments { int requestId; @@ -27,29 +33,39 @@ class MpiMessageBuffer faabric::MPIMessage::MPIMessageType messageType; }; - void addMessage(Arguments arg); - - void deleteMessage(const std::list::iterator& argIt); + /* Interface to query the buffer size */ bool isEmpty(); int size(); - std::list::iterator getRequestArguments(int requestId); + /* Interface to add and delete messages to the buffer */ - std::list::iterator getFirstNullMsgUntil( - const std::list::iterator& argIt); + void addMessage(Arguments arg); + + void deleteMessage(const std::list::iterator& argIt); + + /* Interface to get a pointer to a message in the MMB */ + + // Pointer to a message given its request id + std::list::iterator getRequestArguments(int requestId); + // Pointer to the first null-pointing (unacknowleged) message std::list::iterator getFirstNullMsg(); + /* Interface to ask for the number of unacknowleged messages */ + + // Unacknowledged messages until an iterator (used in await) int getTotalUnackedMessagesUntil( const std::list::iterator& argIt); + // Unacknowledged messages in the whole buffer (used in recv) int getTotalUnackedMessages(); private: - // We keep track of the request id and its arguments. Note that the message - // is part of the arguments and may be null if unacknowleged. std::list args; + + std::list::iterator getFirstNullMsgUntil( + const std::list::iterator& argIt); }; } diff --git a/src/scheduler/MpiMessageBuffer.cpp b/src/scheduler/MpiMessageBuffer.cpp index 3ddd2395a..4ddd672b2 100644 --- a/src/scheduler/MpiMessageBuffer.cpp +++ b/src/scheduler/MpiMessageBuffer.cpp @@ -3,33 +3,29 @@ namespace faabric::scheduler { typedef std::list::iterator ArgListIterator; -void MpiMessageBuffer::addMessage(Arguments arg) +bool MpiMessageBuffer::isEmpty() { - // Ensure we are enqueueing a null message (i.e. 
unacknowleged) - assert(arg.msg == nullptr); - - args.push_back(arg); + return args.empty(); } -void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt) +int MpiMessageBuffer::size() { - args.erase(argIt); - return; + return args.size(); } -bool MpiMessageBuffer::isEmpty() +void MpiMessageBuffer::addMessage(Arguments arg) { - return args.empty(); + args.push_back(arg); } -int MpiMessageBuffer::size() +void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt) { - return args.size(); + args.erase(argIt); } ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId) { - // The request id must be in the UMB, as an irecv must happen before an + // The request id must be in the MMB, as an irecv must happen before an // await ArgListIterator argIt = std::find_if(args.begin(), args.end(), [requestId](Arguments args) { @@ -38,7 +34,7 @@ ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId) // If it's not there, error out if (argIt == args.end()) { - SPDLOG_ERROR("Asynchronous request id not in UMB: {}", requestId); + SPDLOG_ERROR("Asynchronous request id not in buffer: {}", requestId); throw std::runtime_error("Async request not in buffer"); } diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 3ade3ff44..42ea0b8e3 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -5,9 +5,8 @@ #include #include -/* Each MPI rank runs in a separate thread, however they interact with faabric - * as a library. Thus, we use thread_local storage to guarantee that each rank - * sees its own version of these data structures. +/* Each MPI rank runs in a separate thread, thus we use TLS to maintain the + * per-rank data structures. */ static thread_local std::vector< std::unique_ptr> @@ -112,9 +111,7 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank) // thread local nature, we expect it to be quite sparse (i.e. filled with // nullptr). if (unackedMessageBuffers.size() == 0) { - for (int i = 0; i < size * size; i++) { - unackedMessageBuffers.emplace_back(nullptr); - } + unackedMessageBuffers.resize(size * size, nullptr); } // Get the index for the rank-host pair @@ -505,15 +502,6 @@ void MpiWorld::recv(int sendRank, { // Sanity-check input parameters checkRanksRange(sendRank, recvRank); - if (getHostForRank(recvRank) != thisHost) { - SPDLOG_ERROR("Trying to recv message into a non-local rank: {}", - recvRank); - throw std::runtime_error("Receiving message into non-local rank"); - } - - // Work out whether the message is sent locally or from another host - const std::string otherHost = getHostForRank(sendRank); - bool isLocal = otherHost == thisHost; // Recv message from underlying transport std::shared_ptr m = @@ -804,8 +792,8 @@ void MpiWorld::awaitAsyncRequest(int requestId) // Get the corresponding send and recv ranks auto it = reqIdToRanks.find(requestId); - // If the request id is not in the map, the application has issued an - // await() without a previous `isend`, `irecv`, or the actual request id + // If the request id is not in the map, the application either has issued an + // await without a previous isend/irecv, or the actual request id // has been corrupted. In any case, we error out. 
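+    // At this point the request id must correspond to an irecv, as isend
+    // requests have already been handled and returned above.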
if (it == reqIdToRanks.end()) { SPDLOG_ERROR("Asynchronous request id not recognized: {}", requestId); diff --git a/tests/test/scheduler/test_mpi_message_buffer.cpp b/tests/test/scheduler/test_mpi_message_buffer.cpp new file mode 100644 index 000000000..16665771f --- /dev/null +++ b/tests/test/scheduler/test_mpi_message_buffer.cpp @@ -0,0 +1,119 @@ +#include + +#include +#include +#include + +using namespace faabric::scheduler; + +MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true, + int overrideRequestId = -1) +{ + int requestId; + if (overrideRequestId != -1) { + requestId = overrideRequestId; + } else { + requestId = static_cast(faabric::util::generateGid()); + } + + MpiMessageBuffer::Arguments args = { + requestId, nullptr, 0, 1, + nullptr, MPI_INT, 0, faabric::MPIMessage::NORMAL + }; + + if (!nullMsg) { + args.msg = std::make_shared(); + } + + return args; +} + +namespace tests { +TEST_CASE("Test adding message to message buffer", "[mpi]") +{ + MpiMessageBuffer mmb; + REQUIRE(mmb.isEmpty()); + + REQUIRE_NOTHROW(mmb.addMessage(genRandomArguments())); + REQUIRE(mmb.size() == 1); +} + +TEST_CASE("Test deleting message from message buffer", "[mpi]") +{ + MpiMessageBuffer mmb; + REQUIRE(mmb.isEmpty()); + + mmb.addMessage(genRandomArguments()); + REQUIRE(mmb.size() == 1); + + auto it = mmb.getFirstNullMsg(); + REQUIRE_NOTHROW(mmb.deleteMessage(it)); + + REQUIRE(mmb.isEmpty()); +} + +TEST_CASE("Test getting an iterator from a request id", "[mpi]") +{ + MpiMessageBuffer mmb; + + int requestId = 1337; + mmb.addMessage(genRandomArguments(true, requestId)); + + auto it = mmb.getRequestArguments(requestId); + REQUIRE(it->requestId == requestId); +} + +TEST_CASE("Test getting first null message", "[mpi]") +{ + MpiMessageBuffer mmb; + + // Add first a non-null message + int requestId1 = 1; + mmb.addMessage(genRandomArguments(false, requestId1)); + + // Then add a null message + int requestId2 = 2; + mmb.addMessage(genRandomArguments(true, requestId2)); + + // Query for the first non-null message + auto it = mmb.getFirstNullMsg(); + REQUIRE(it->requestId == requestId2); +} + +TEST_CASE("Test getting total unacked messages in message buffer", "[mpi]") +{ + MpiMessageBuffer mmb; + + REQUIRE(mmb.getTotalUnackedMessages() == 0); + + // Add a non-null message + mmb.addMessage(genRandomArguments(false)); + + // Then a couple of null messages + mmb.addMessage(genRandomArguments(true)); + mmb.addMessage(genRandomArguments(true)); + + // Check that we have two unacked messages + REQUIRE(mmb.getTotalUnackedMessages() == 2); +} + +TEST_CASE("Test getting total unacked messages in message buffer range", + "[mpi]") +{ + MpiMessageBuffer mmb; + + // Add a non-null message + mmb.addMessage(genRandomArguments(false)); + + // Then a couple of null messages + int requestId = 1337; + mmb.addMessage(genRandomArguments(true)); + mmb.addMessage(genRandomArguments(true, requestId)); + + // Get an iterator to our second null message + auto it = mmb.getRequestArguments(requestId); + + // Check that we have only one unacked message until the iterator + REQUIRE(mmb.getTotalUnackedMessagesUntil(it) == 1); +} +} diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 6a70a3be6..5436f2eb9 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -67,9 +67,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, 
worldId); faabric::util::setMockMode(false); - std::thread senderThread([this, rankA, rankB] { - std::vector messageData = { 0, 1, 2 }; - + std::thread senderThread([this, rankA, rankB, &messageData] { remoteWorld.initialiseFromMsg(msg); // Send a message that should get sent to this host @@ -78,21 +76,18 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") remoteWorld.destroy(); }); - SECTION("Check recv") - { - // Receive the message for the given rank - MPI_Status status{}; - auto buffer = new int[messageData.size()]; - localWorld.recv( - rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); + // Receive the message for the given rank + MPI_Status status{}; + auto buffer = new int[messageData.size()]; + localWorld.recv( + rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); - std::vector actual(buffer, buffer + messageData.size()); - REQUIRE(actual == messageData); + std::vector actual(buffer, buffer + messageData.size()); + REQUIRE(actual == messageData); - REQUIRE(status.MPI_SOURCE == rankB); - REQUIRE(status.MPI_ERROR == MPI_SUCCESS); - REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); - } + REQUIRE(status.MPI_SOURCE == rankB); + REQUIRE(status.MPI_ERROR == MPI_SUCCESS); + REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); // Destroy worlds senderThread.join(); @@ -377,9 +372,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - std::thread senderThread([this, sendRank, recvRank] { - std::vector messageData = { 0, 1, 2 }; - + std::thread senderThread([this, sendRank, recvRank, &messageData] { remoteWorld.initialiseFromMsg(msg); // Send message twice From 0a7ab7c0fe42ffb47c356cb1614b946b92b0e46b Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Tue, 15 Jun 2021 15:02:30 +0000 Subject: [PATCH 6/8] pr comments --- .github/workflows/tests.yml | 2 +- include/faabric/scheduler/MpiMessageBuffer.h | 45 +++++++---- include/faabric/scheduler/MpiWorld.h | 4 +- src/scheduler/MpiMessageBuffer.cpp | 56 +++++++------ src/scheduler/MpiWorld.cpp | 79 ++++++++++--------- .../scheduler/test_mpi_message_buffer.cpp | 19 +++-- tests/test/scheduler/test_mpi_world.cpp | 57 +++++++++++++ 7 files changed, 170 insertions(+), 92 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 58b8536ba..bc80693a9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -68,7 +68,7 @@ jobs: run: inv dev.cc faabric_tests # --- Tests --- - name: "Run tests" - run: ./bin/faabric_tests + run: LOG_LEVEL=trace ./bin/faabric_tests working-directory: /build/faabric/static dist-tests: diff --git a/include/faabric/scheduler/MpiMessageBuffer.h b/include/faabric/scheduler/MpiMessageBuffer.h index 9715ee346..59cc39120 100644 --- a/include/faabric/scheduler/MpiMessageBuffer.h +++ b/include/faabric/scheduler/MpiMessageBuffer.h @@ -21,16 +21,25 @@ class MpiMessageBuffer * buffer. Note that the message field will point to null if unacknowleged * or to a valid message otherwise. 
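+     * The request id stored here is the one generated by irecv and later
+     * passed to awaitAsyncRequest; isend requests are tracked separately
+     * and never enter this buffer.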
*/ - struct Arguments + class PendingAsyncMpiMessage { - int requestId; - std::shared_ptr msg; - int sendRank; - int recvRank; - uint8_t* buffer; - faabric_datatype_t* dataType; - int count; - faabric::MPIMessage::MPIMessageType messageType; + public: + int requestId = -1; + std::shared_ptr msg = nullptr; + int sendRank = -1; + int recvRank = -1; + uint8_t* buffer = nullptr; + faabric_datatype_t* dataType = nullptr; + int count = -1; + faabric::MPIMessage::MPIMessageType messageType = + faabric::MPIMessage::NORMAL; + + bool isAcknowledged() { return msg != nullptr; } + + void acknowledge(std::shared_ptr msgIn) + { + msg = msgIn; + } }; /* Interface to query the buffer size */ @@ -41,31 +50,33 @@ class MpiMessageBuffer /* Interface to add and delete messages to the buffer */ - void addMessage(Arguments arg); + void addMessage(PendingAsyncMpiMessage msg); - void deleteMessage(const std::list::iterator& argIt); + void deleteMessage( + const std::list::iterator& msgIt); /* Interface to get a pointer to a message in the MMB */ // Pointer to a message given its request id - std::list::iterator getRequestArguments(int requestId); + std::list::iterator getRequestPendingMsg( + int requestId); // Pointer to the first null-pointing (unacknowleged) message - std::list::iterator getFirstNullMsg(); + std::list::iterator getFirstNullMsg(); /* Interface to ask for the number of unacknowleged messages */ // Unacknowledged messages until an iterator (used in await) int getTotalUnackedMessagesUntil( - const std::list::iterator& argIt); + const std::list::iterator& msgIt); // Unacknowledged messages in the whole buffer (used in recv) int getTotalUnackedMessages(); private: - std::list args; + std::list pendingMsgs; - std::list::iterator getFirstNullMsgUntil( - const std::list::iterator& argIt); + std::list::iterator getFirstNullMsgUntil( + const std::list::iterator& msgIt); }; } diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index f3b0d00c1..256d23f97 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -215,8 +215,8 @@ class MpiWorld void closeMpiMessageEndpoints(); // Support for asyncrhonous communications - std::shared_ptr - getUnackedMessageBuffer(int sendRank, int recvRank); + std::shared_ptr getUnackedMessageBuffer(int sendRank, + int recvRank); std::shared_ptr recvBatchReturnLast(int sendRank, int recvRank, int batchSize = 0); diff --git a/src/scheduler/MpiMessageBuffer.cpp b/src/scheduler/MpiMessageBuffer.cpp index 4ddd672b2..8b2445e40 100644 --- a/src/scheduler/MpiMessageBuffer.cpp +++ b/src/scheduler/MpiMessageBuffer.cpp @@ -2,68 +2,72 @@ #include namespace faabric::scheduler { -typedef std::list::iterator ArgListIterator; +typedef std::list::iterator + MpiMessageIterator; bool MpiMessageBuffer::isEmpty() { - return args.empty(); + return pendingMsgs.empty(); } int MpiMessageBuffer::size() { - return args.size(); + return pendingMsgs.size(); } -void MpiMessageBuffer::addMessage(Arguments arg) +void MpiMessageBuffer::addMessage(PendingAsyncMpiMessage msg) { - args.push_back(arg); + pendingMsgs.push_back(msg); } -void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt) +void MpiMessageBuffer::deleteMessage(const MpiMessageIterator& msgIt) { - args.erase(argIt); + pendingMsgs.erase(msgIt); } -ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId) +MpiMessageIterator MpiMessageBuffer::getRequestPendingMsg(int requestId) { // The request id must be in the MMB, as an irecv must happen before an // 
await - ArgListIterator argIt = - std::find_if(args.begin(), args.end(), [requestId](Arguments args) { - return args.requestId == requestId; - }); + MpiMessageIterator msgIt = + std::find_if(pendingMsgs.begin(), + pendingMsgs.end(), + [requestId](PendingAsyncMpiMessage pendingMsg) { + return pendingMsg.requestId == requestId; + }); // If it's not there, error out - if (argIt == args.end()) { + if (msgIt == pendingMsgs.end()) { SPDLOG_ERROR("Asynchronous request id not in buffer: {}", requestId); throw std::runtime_error("Async request not in buffer"); } - return argIt; + return msgIt; } -ArgListIterator MpiMessageBuffer::getFirstNullMsgUntil( - const ArgListIterator& argItEnd) +MpiMessageIterator MpiMessageBuffer::getFirstNullMsgUntil( + const MpiMessageIterator& msgItEnd) { - return std::find_if(args.begin(), argItEnd, [](Arguments args) { - return args.msg == nullptr; - }); + return std::find_if( + pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage pendingMsg) { + return pendingMsg.msg == nullptr; + }); } -ArgListIterator MpiMessageBuffer::getFirstNullMsg() +MpiMessageIterator MpiMessageBuffer::getFirstNullMsg() { - return getFirstNullMsgUntil(args.end()); + return getFirstNullMsgUntil(pendingMsgs.end()); } int MpiMessageBuffer::getTotalUnackedMessagesUntil( - const ArgListIterator& argItEnd) + const MpiMessageIterator& msgItEnd) { - ArgListIterator firstNull = getFirstNullMsgUntil(argItEnd); - return std::distance(firstNull, argItEnd); + MpiMessageIterator firstNull = getFirstNullMsgUntil(msgItEnd); + return std::distance(firstNull, msgItEnd); } int MpiMessageBuffer::getTotalUnackedMessages() { - ArgListIterator firstNull = getFirstNullMsg(); - return std::distance(firstNull, args.end()); + MpiMessageIterator firstNull = getFirstNullMsg(); + return std::distance(firstNull, pendingMsgs.end()); } } diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 42ea0b8e3..f9716001a 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -207,11 +207,19 @@ void MpiWorld::destroy() // Request to rank map should be empty if (!reqIdToRanks.empty()) { SPDLOG_ERROR( - "Destroying the MPI world with {} outstanding async requests", + "Destroying the MPI world with {} outstanding irecv requests", reqIdToRanks.size()); throw std::runtime_error("Destroying world with outstanding requests"); } + // iSend set should be empty + if (!iSendRequests.empty()) { + SPDLOG_ERROR( + "Destroying the MPI world with {} outstanding isend requests", + iSendRequests.size()); + throw std::runtime_error("Destroying world with outstanding requests"); + } + // Destroy once per host the shared resources if (!isDestroyed.test_and_set()) { // Wait (forever) until all ranks are done consuming their queues to @@ -423,15 +431,19 @@ int MpiWorld::irecv(int sendRank, int requestId = (int)faabric::util::generateGid(); reqIdToRanks.try_emplace(requestId, sendRank, recvRank); - // Enqueue a request with a null-pointing message (i.e. 
unacknowleged) and - // the generated request id - faabric::scheduler::MpiMessageBuffer::Arguments args = { - requestId, nullptr, sendRank, recvRank, - buffer, dataType, count, messageType - }; + // Enqueue an unacknowleged request (no message) + faabric::scheduler::MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg; + pendingMsg.requestId = requestId; + pendingMsg.sendRank = sendRank; + pendingMsg.recvRank = recvRank; + pendingMsg.buffer = buffer; + pendingMsg.dataType = dataType; + pendingMsg.count = count; + pendingMsg.messageType = messageType; + assert(!pendingMsg.isAcknowledged()); auto umb = getUnackedMessageBuffer(sendRank, recvRank); - umb->addMessage(args); + umb->addMessage(pendingMsg); return requestId; } @@ -806,29 +818,29 @@ void MpiWorld::awaitAsyncRequest(int requestId) std::shared_ptr umb = getUnackedMessageBuffer(sendRank, recvRank); - std::list::iterator argsIndex = - umb->getRequestArguments(requestId); + std::list::iterator msgIt = + umb->getRequestPendingMsg(requestId); std::shared_ptr m; - if (argsIndex->msg != nullptr) { + if (msgIt->msg != nullptr) { // This id has already been acknowledged by a recv call, so do the recv - m = argsIndex->msg; + m = msgIt->msg; } else { // We need to acknowledge all messages not acknowledged from the // begining until us m = recvBatchReturnLast( - sendRank, recvRank, umb->getTotalUnackedMessagesUntil(argsIndex) + 1); + sendRank, recvRank, umb->getTotalUnackedMessagesUntil(msgIt) + 1); } doRecv(m, - argsIndex->buffer, - argsIndex->dataType, - argsIndex->count, + msgIt->buffer, + msgIt->dataType, + msgIt->count, MPI_STATUS_IGNORE, - argsIndex->messageType); + msgIt->messageType); // Remove the acknowledged indexes from the UMB - umb->deleteMessage(argsIndex); + umb->deleteMessage(msgIt); } void MpiWorld::reduce(int sendRank, @@ -1213,47 +1225,42 @@ MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) // Recv message: first we receive all messages for which there is an id // in the unacknowleged buffer but no msg. Note that these messages // (batchSize - 1) were `irecv`-ed before ours. 
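    // A minimal calling sketch of the ordering this batch receive preserves
    // (rank numbers, buffers and counts below are illustrative, not taken
    // from the patch):
    //
    //   int reqA = world.irecv(0, 1, BYTES(bufA.data()), MPI_INT, 3);
    //   int reqB = world.irecv(0, 1, BYTES(bufB.data()), MPI_INT, 3);
    //   world.recv(0, 1, BYTES(bufC.data()), MPI_INT, 3, &status);
    //
    // The blocking recv reaches this method with batchSize == 2 + 1: it must
    // first drain the two messages matching the earlier irecv calls from the
    // local queue (or remote endpoint), acknowledge them against their
    // entries in the MpiMessageBuffer, and only then return the third
    // message to its caller. A later awaitAsyncRequest(reqA) then finds its
    // message already acknowledged and simply copies it into bufA.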
- std::shared_ptr m; - auto argsIt = umb->getFirstNullMsg(); + std::shared_ptr ourMsg; + auto msgIt = umb->getFirstNullMsg(); if (isLocal) { // First receive messages that happened before us for (int i = 0; i < batchSize - 1; i++) { SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank); - auto _m = getLocalQueue(sendRank, recvRank)->dequeue(); - - assert(_m != nullptr); - assert(argsIt->msg == nullptr); + auto pendingMsg = getLocalQueue(sendRank, recvRank)->dequeue(); // Put the unacked message in the UMB - argsIt->msg = _m; - argsIt++; + assert(!msgIt->isAcknowledged()); + msgIt->acknowledge(pendingMsg); + msgIt++; } // Finally receive the message corresponding to us SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank); - m = getLocalQueue(sendRank, recvRank)->dequeue(); + ourMsg = getLocalQueue(sendRank, recvRank)->dequeue(); } else { // First receive messages that happened before us for (int i = 0; i < batchSize - 1; i++) { SPDLOG_TRACE( "MPI - pending remote recv {} -> {}", sendRank, recvRank); - auto _m = recvRemoteMpiMessage(sendRank, recvRank); - - assert(_m != nullptr); - assert(argsIt->msg == nullptr); + auto pendingMsg = recvRemoteMpiMessage(sendRank, recvRank); // Put the unacked message in the UMB - argsIt->msg = _m; - argsIt++; + assert(!msgIt->isAcknowledged()); + msgIt->acknowledge(pendingMsg); + msgIt++; } // Finally receive the message corresponding to us SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank); - m = recvRemoteMpiMessage(sendRank, recvRank); + ourMsg = recvRemoteMpiMessage(sendRank, recvRank); } - assert(m != nullptr); - return m; + return ourMsg; } int MpiWorld::getIndexForRanks(int sendRank, int recvRank) diff --git a/tests/test/scheduler/test_mpi_message_buffer.cpp b/tests/test/scheduler/test_mpi_message_buffer.cpp index 16665771f..40a845c0e 100644 --- a/tests/test/scheduler/test_mpi_message_buffer.cpp +++ b/tests/test/scheduler/test_mpi_message_buffer.cpp @@ -6,8 +6,9 @@ using namespace faabric::scheduler; -MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true, - int overrideRequestId = -1) +MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments( + bool nullMsg = true, + int overrideRequestId = -1) { int requestId; if (overrideRequestId != -1) { @@ -16,16 +17,14 @@ MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true, requestId = static_cast(faabric::util::generateGid()); } - MpiMessageBuffer::Arguments args = { - requestId, nullptr, 0, 1, - nullptr, MPI_INT, 0, faabric::MPIMessage::NORMAL - }; + MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg; + pendingMsg.requestId = requestId; if (!nullMsg) { - args.msg = std::make_shared(); + pendingMsg.msg = std::make_shared(); } - return args; + return pendingMsg; } namespace tests { @@ -59,7 +58,7 @@ TEST_CASE("Test getting an iterator from a request id", "[mpi]") int requestId = 1337; mmb.addMessage(genRandomArguments(true, requestId)); - auto it = mmb.getRequestArguments(requestId); + auto it = mmb.getRequestPendingMsg(requestId); REQUIRE(it->requestId == requestId); } @@ -111,7 +110,7 @@ TEST_CASE("Test getting total unacked messages in message buffer range", mmb.addMessage(genRandomArguments(true, requestId)); // Get an iterator to our second null message - auto it = mmb.getRequestArguments(requestId); + auto it = mmb.getRequestPendingMsg(requestId); // Check that we have only one unacked message until the iterator REQUIRE(mmb.getTotalUnackedMessagesUntil(it) == 1); diff --git a/tests/test/scheduler/test_mpi_world.cpp 
b/tests/test/scheduler/test_mpi_world.cpp index 3d626e10e..a4a3b0391 100644 --- a/tests/test/scheduler/test_mpi_world.cpp +++ b/tests/test/scheduler/test_mpi_world.cpp @@ -1202,4 +1202,61 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test all-to-all", "[mpi]") world.destroy(); } + +TEST_CASE_METHOD(MpiTestFixture, + "Test can't destroy world with outstanding requests", + "[mpi]") +{ + int rankA = 0; + int rankB = 1; + int data = 9; + int actual = -1; + + SECTION("Outstanding irecv") + { + world.send(rankA, rankB, BYTES(&data), MPI_INT, 1); + int recvId = world.irecv(rankA, rankB, BYTES(&actual), MPI_INT, 1); + + REQUIRE_THROWS(world.destroy()); + + world.awaitAsyncRequest(recvId); + REQUIRE(actual == data); + } + + SECTION("Outstanding acknowledged irecv") + { + int data2 = 14; + int actual2 = -1; + + world.send(rankA, rankB, BYTES(&data), MPI_INT, 1); + world.send(rankA, rankB, BYTES(&data2), MPI_INT, 1); + int recvId = world.irecv(rankA, rankB, BYTES(&actual), MPI_INT, 1); + int recvId2 = world.irecv(rankA, rankB, BYTES(&actual2), MPI_INT, 1); + + REQUIRE_THROWS(world.destroy()); + + // Await for the second request, which will acknowledge the first one + // but not remove it from the pending message buffer + world.awaitAsyncRequest(recvId2); + + REQUIRE_THROWS(world.destroy()); + + // Await for the first one + world.awaitAsyncRequest(recvId); + + REQUIRE(actual == data); + REQUIRE(actual2 == data2); + } + + SECTION("Outstanding isend") + { + int sendId = world.isend(rankA, rankB, BYTES(&data), MPI_INT, 1); + world.recv(rankA, rankB, BYTES(&actual), MPI_INT, 1, MPI_STATUS_IGNORE); + + REQUIRE_THROWS(world.destroy()); + + world.awaitAsyncRequest(sendId); + REQUIRE(actual == data); + } +} } From 88bcb7d6a6c88f8118301d8948fd9b39e40f3acc Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Wed, 16 Jun 2021 08:03:36 +0000 Subject: [PATCH 7/8] switching to per-world port range --- .github/workflows/tests.yml | 2 +- include/faabric/scheduler/MpiWorld.h | 9 +- include/faabric/transport/MessageEndpoint.h | 2 - .../faabric/transport/MpiMessageEndpoint.h | 4 +- src/proto/faabric.proto | 1 + src/scheduler/MpiWorld.cpp | 131 ++++++++++++------ src/transport/MessageEndpoint.cpp | 5 - src/transport/MpiMessageEndpoint.cpp | 8 +- .../transport/test_mpi_message_endpoint.cpp | 4 +- 9 files changed, 104 insertions(+), 62 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bc80693a9..58b8536ba 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -68,7 +68,7 @@ jobs: run: inv dev.cc faabric_tests # --- Tests --- - name: "Run tests" - run: LOG_LEVEL=trace ./bin/faabric_tests + run: ./bin/faabric_tests working-directory: /build/faabric/static dist-tests: diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index 256d23f97..fc91d4bbb 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -27,7 +27,7 @@ class MpiWorld std::string getHostForRank(int rank); - void setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg); + void setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg); std::string getUser(); @@ -205,8 +205,11 @@ class MpiWorld void initLocalQueues(); // Rank-to-rank sockets for remote messaging - void initRemoteMpiEndpoint(int sendRank, int recvRank); - int getMpiPort(int sendRank, int recvRank); + std::vector basePorts; + std::vector initLocalBasePorts( + const std::vector& executedAt); + void initRemoteMpiEndpoint(int localRank, int remoteRank); + 
std::pair getPortForRanks(int localRank, int remoteRank); void sendRemoteMpiMessage(int sendRank, int recvRank, const std::shared_ptr& msg); diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 658448455..4e463e07a 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -83,8 +83,6 @@ class RecvMessageEndpoint : public MessageEndpoint public: RecvMessageEndpoint(int portIn); - RecvMessageEndpoint(int portIn, const std::string& overrideHost); - void open(MessageContext& context); void close(); diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 79c402a12..991331bc5 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -23,9 +23,7 @@ class MpiMessageEndpoint public: MpiMessageEndpoint(const std::string& hostIn, int portIn); - MpiMessageEndpoint(const std::string& hostIn, - int portIn, - const std::string& overrideRecvHost); + MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort); void sendMpiMessage(const std::shared_ptr& msg); diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index 80103ca1d..a44a4668c 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -88,6 +88,7 @@ message MPIMessage { // fields. message MpiHostsToRanksMessage { repeated string hosts = 1; + repeated int32 basePorts = 2; } message Message { diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index f9716001a..5f9e7bfc1 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -26,8 +26,12 @@ MpiWorld::MpiWorld() , cartProcsPerDim(2) {} -void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank) +void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank) { + SPDLOG_TRACE("Open MPI endpoint between ranks (local-remote) {} - {}", + localRank, + remoteRank); + // Resize the message endpoint vector and initialise to null. 
Note that we // allocate size x size slots to cover all possible (sendRank, recvRank) // pairs @@ -37,42 +41,20 @@ void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank) } } - // Get host for recv rank - std::string otherHost; - std::string recvHost = getHostForRank(recvRank); - std::string sendHost = getHostForRank(sendRank); - if (recvHost == sendHost) { - SPDLOG_ERROR( - "Send and recv ranks in the same host: SEND {}, RECV{} in {}", - sendRank, - recvRank, - sendHost); - throw std::runtime_error("Send and recv ranks in the same host"); - } else if (recvHost == thisHost) { - otherHost = sendHost; - } else if (sendHost == thisHost) { - otherHost = recvHost; - } else { - SPDLOG_ERROR("Send and recv ranks correspond to remote hosts: SEND {} " - "in {}, RECV {} in {}", - sendRank, - sendHost, - recvRank, - recvHost); - throw std::runtime_error("Send and recv ranks in remote hosts"); - } + // Get host for remote rank + std::string otherHost = getHostForRank(remoteRank); // Get the index for the rank-host pair - int index = getIndexForRanks(sendRank, recvRank); + int index = getIndexForRanks(localRank, remoteRank); // Get port for send-recv pair - int port = getMpiPort(sendRank, recvRank); + std::pair sendRecvPorts = getPortForRanks(localRank, remoteRank); // Create MPI message endpoint mpiMessageEndpoints.emplace( mpiMessageEndpoints.begin() + index, std::make_unique( - otherHost, port, thisHost)); + otherHost, sendRecvPorts.first, sendRecvPorts.second)); } void MpiWorld::sendRemoteMpiMessage( @@ -81,6 +63,8 @@ void MpiWorld::sendRemoteMpiMessage( const std::shared_ptr& msg) { // Get the index for the rank-host pair + // Note - message endpoints are identified by a (localRank, remoteRank) + // pair, not a (sendRank, recvRank) one int index = getIndexForRanks(sendRank, recvRank); if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) { @@ -95,10 +79,12 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage( int recvRank) { // Get the index for the rank-host pair - int index = getIndexForRanks(sendRank, recvRank); + // Note - message endpoints are identified by a (localRank, remoteRank) + // pair, not a (sendRank, recvRank) one + int index = getIndexForRanks(recvRank, sendRank); if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) { - initRemoteMpiEndpoint(sendRank, recvRank); + initRemoteMpiEndpoint(recvRank, sendRank); } return mpiMessageEndpoints[index]->recvMpiMessage(); @@ -160,7 +146,14 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) // Register hosts to rank mappings on this host faabric::MpiHostsToRanksMessage hostRankMsg; *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() }; - setAllRankHosts(hostRankMsg); + + // Prepare the base port for each rank + std::vector basePortForRank = initLocalBasePorts(executedAt); + *hostRankMsg.mutable_baseports() = { basePortForRank.begin(), + basePortForRank.end() }; + + // Register hosts to rank mappins on this host + setAllRankHostsPorts(hostRankMsg); // Set up a list of hosts to broadcast to (excluding this host) std::set hosts(executedAt.begin(), executedAt.end()); @@ -249,7 +242,7 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) // Block until we receive faabric::MpiHostsToRanksMessage hostRankMsg = faabric::transport::recvMpiHostRankMsg(); - setAllRankHosts(hostRankMsg); + setAllRankHostsPorts(hostRankMsg); // Initialise the memory queues for message reception initLocalQueues(); @@ -269,16 +262,51 @@ std::string 
MpiWorld::getHostForRank(int rank) return host; } +// Returns a pair (sendPort, recvPort) +// To assign the send and recv ports, we follow a protocol establishing: +// 1) Port range (offset) corresponding to the world that receives +// 2) Within a world's port range, port corresponding to the outcome of +// getIndexForRanks(localRank, remoteRank) Where local and remote are +// relative to the world whose port range we are in +std::pair MpiWorld::getPortForRanks(int localRank, int remoteRank) +{ + std::pair sendRecvPortPair; + + // Get base port for local and remote worlds + int localBasePort = basePorts[localRank]; + int remoteBasePort = basePorts[remoteRank]; + assert(localBasePort != remoteBasePort); + + // Assign send port + // 1) Port range corresponding to remote world, as they are receiving + // 2) Index switching localRank and remoteRank, as remote rank is "local" + // to the remote world + sendRecvPortPair.first = + remoteBasePort + getIndexForRanks(remoteRank, localRank); + + // Assign recv port + // 1) Port range corresponding to our world, as we are the one's receiving + // 2) Port using our local rank as `localRank`, as we are in the local + // offset + sendRecvPortPair.second = + localBasePort + getIndexForRanks(localRank, remoteRank); + + return sendRecvPortPair; +} + // Prepare the host-rank map with a vector containing _all_ ranks // Note - this method should be called by only one rank. This is enforced in // the world registry -void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg) +void MpiWorld::setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg) { // Assert we are only setting the values once assert(rankHosts.size() == 0); + assert(basePorts.size() == 0); assert(msg.hosts().size() == size); + assert(msg.baseports().size() == size); rankHosts = { msg.hosts().begin(), msg.hosts().end() }; + basePorts = { msg.baseports().begin(), msg.baseports().end() }; } void MpiWorld::getCartesianRank(int rank, @@ -448,15 +476,6 @@ int MpiWorld::irecv(int sendRank, return requestId; } -int MpiWorld::getMpiPort(int sendRank, int recvRank) -{ - // TODO - get port in a multi-tenant-safe manner - int basePort = MPI_PORT; - int rankOffset = sendRank * size + recvRank; - - return basePort + rankOffset; -} - void MpiWorld::send(int sendRank, int recvRank, const uint8_t* buffer, @@ -1204,6 +1223,34 @@ void MpiWorld::initLocalQueues() } } +// Here we rely on the scheduler returning a list of hosts where equal +// hosts are always contiguous with the exception of the master host +// (thisHost) which may appear repeated at the end if the system is +// overloaded. 
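+//
+// Worked example (illustrative values, and assuming getIndexForRanks(a, b)
+// is laid out as a * size + b, which matches the size * size ports reserved
+// per host below): a world of size 2 with one rank per host and
+// executedAt = { thisHost, otherHost } yields
+// basePortForRank = { MPI_PORT, MPI_PORT + 4 }. Rank 0 on this host then
+// gets getPortForRanks(0, 1) == { MPI_PORT + 6, MPI_PORT + 1 }, while rank 1
+// on the other host gets getPortForRanks(1, 0) == { MPI_PORT + 1, MPI_PORT + 6 }:
+// each side's send port is the other side's recv port, so the two endpoints
+// pair up without any extra negotiation.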
+std::vector MpiWorld::initLocalBasePorts( + const std::vector& executedAt) +{ + std::vector basePortForRank; + basePortForRank.reserve(size); + + std::string lastHost = thisHost; + int lastPort = MPI_PORT; + for (const auto& host : executedAt) { + if (host == thisHost) { + basePortForRank.push_back(MPI_PORT); + } else if (host == lastHost) { + basePortForRank.push_back(lastPort); + } else { + lastHost = host; + lastPort += size * size; + basePortForRank.push_back(lastPort); + } + } + + assert(basePortForRank.size() == size); + return basePortForRank; +} + std::shared_ptr MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) { diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 5f87412b3..1cb71a871 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -279,11 +279,6 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn) : MessageEndpoint(ANY_HOST, portIn) {} -RecvMessageEndpoint::RecvMessageEndpoint(int portIn, - const std::string& overrideHost) - : MessageEndpoint(overrideHost, portIn) -{} - void RecvMessageEndpoint::open(MessageContext& context) { SPDLOG_TRACE( diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 4aebbefcd..67be90801 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -37,10 +37,10 @@ MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) } MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, - int portIn, - const std::string& overrideRecvHost) - : sendMessageEndpoint(hostIn, portIn) - , recvMessageEndpoint(portIn, overrideRecvHost) + int sendPort, + int recvPort) + : sendMessageEndpoint(hostIn, sendPort) + , recvMessageEndpoint(recvPort) { sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp index a5b6a4ced..a27ae9bf7 100644 --- a/tests/test/transport/test_mpi_message_endpoint.cpp +++ b/tests/test/transport/test_mpi_message_endpoint.cpp @@ -31,8 +31,8 @@ TEST_CASE_METHOD(MessageContextFixture, "[transport]") { std::string thisHost = faabric::util::getSystemConfig().endpointHost; - MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, thisHost); - MpiMessageEndpoint recvEndpoint(thisHost, 9999, LOCALHOST); + MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, 9998); + MpiMessageEndpoint recvEndpoint(thisHost, 9998, 9999); std::shared_ptr expected = std::make_shared(); From 142284392aad21382d9a686232d69fd20c4b9e35 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Wed, 16 Jun 2021 11:03:27 +0000 Subject: [PATCH 8/8] adding more tests --- .../test/scheduler/test_remote_mpi_worlds.cpp | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 5436f2eb9..cbe84d573 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -73,6 +73,9 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") // Send a message that should get sent to this host remoteWorld.send( rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); + + usleep(1000 * 500); + remoteWorld.destroy(); }); @@ -94,6 +97,65 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") 
localWorld.destroy(); } +TEST_CASE_METHOD(RemoteMpiTestFixture, + "Test send and recv across hosts", + "[mpi]") +{ + // Register two ranks (one on each host) + this->setWorldsSizes(2, 1, 1); + int rankA = 0; + int rankB = 1; + std::vector messageData = { 0, 1, 2 }; + std::vector messageData2 = { 3, 4, 5 }; + + // Init worlds + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + + std::thread senderThread([this, rankA, rankB, &messageData, &messageData2] { + remoteWorld.initialiseFromMsg(msg); + + // Send a message that should get sent to this host + remoteWorld.send( + rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); + + // Now recv + auto buffer = new int[messageData2.size()]; + remoteWorld.recv(rankA, + rankB, + BYTES(buffer), + MPI_INT, + messageData2.size(), + MPI_STATUS_IGNORE); + std::vector actual(buffer, buffer + messageData2.size()); + REQUIRE(actual == messageData2); + + usleep(1000 * 500); + + remoteWorld.destroy(); + }); + + // Receive the message for the given rank + MPI_Status status{}; + auto buffer = new int[messageData.size()]; + localWorld.recv( + rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); + std::vector actual(buffer, buffer + messageData.size()); + REQUIRE(actual == messageData); + + // Now send a message + localWorld.send( + rankA, rankB, BYTES(messageData2.data()), MPI_INT, messageData2.size()); + + REQUIRE(status.MPI_SOURCE == rankB); + REQUIRE(status.MPI_ERROR == MPI_SUCCESS); + REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); + + // Destroy worlds + senderThread.join(); + localWorld.destroy(); +} + TEST_CASE_METHOD(RemoteMpiTestFixture, "Test sending many messages across host", "[mpi]") @@ -114,6 +176,9 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, for (int i = 0; i < numMessages; i++) { remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } + + usleep(1000 * 500); + remoteWorld.destroy(); }); @@ -165,6 +230,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, assert(actual == messageData); } + usleep(1000 * 500); + remoteWorld.destroy(); }); @@ -237,6 +304,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); + usleep(1000 * 500); + remoteWorld.destroy(); }); @@ -324,6 +393,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); } + usleep(1000 * 500); + remoteWorld.destroy(); }); @@ -387,6 +458,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); + usleep(1000 * 500); + remoteWorld.destroy(); }); @@ -440,6 +513,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, remoteWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } + usleep(1000 * 500); + remoteWorld.destroy(); }); @@ -512,6 +587,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_STATUS_IGNORE); } + usleep(1000 * 500); + remoteWorld.destroy(); });
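    // Note on the usleep(1000 * 500) calls added in these tests: the
    // argument is in microseconds, so each call pauses the sender thread
    // for 500 ms before remoteWorld.destroy(), presumably to give the
    // receiving world time to drain the rank-to-rank channel before the
    // remote endpoints are torn down. A named constant (hypothetical, not
    // part of this patch series) would keep that intent visible at every
    // call site:
    //
    //   constexpr useconds_t TEST_DRAIN_SLEEP_US = 500 * 1000; // 500 ms
    //   usleep(TEST_DRAIN_SLEEP_US);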