From f59c0565a14cf1dd96f2b517640aca38b24643bc Mon Sep 17 00:00:00 2001
From: Carlos Segarra
Date: Tue, 15 Jun 2021 15:02:30 +0000
Subject: [PATCH] pr comments

---
 include/faabric/scheduler/MpiMessageBuffer.h | 45 +++++++----
 include/faabric/scheduler/MpiWorld.h         |  4 +-
 src/scheduler/MpiMessageBuffer.cpp           | 56 +++++++------
 src/scheduler/MpiWorld.cpp                   | 79 ++++++++++---------
 .../scheduler/test_mpi_message_buffer.cpp    | 19 +++--
 tests/test/scheduler/test_mpi_world.cpp      | 57 +++++++++++++
 6 files changed, 169 insertions(+), 91 deletions(-)

diff --git a/include/faabric/scheduler/MpiMessageBuffer.h b/include/faabric/scheduler/MpiMessageBuffer.h
index 9715ee346..59cc39120 100644
--- a/include/faabric/scheduler/MpiMessageBuffer.h
+++ b/include/faabric/scheduler/MpiMessageBuffer.h
@@ -21,16 +21,25 @@ class MpiMessageBuffer
      * buffer. Note that the message field will point to null if unacknowledged
      * or to a valid message otherwise.
      */
-    struct Arguments
+    class PendingAsyncMpiMessage
     {
-        int requestId;
-        std::shared_ptr<faabric::MPIMessage> msg;
-        int sendRank;
-        int recvRank;
-        uint8_t* buffer;
-        faabric_datatype_t* dataType;
-        int count;
-        faabric::MPIMessage::MPIMessageType messageType;
+      public:
+        int requestId = -1;
+        std::shared_ptr<faabric::MPIMessage> msg = nullptr;
+        int sendRank = -1;
+        int recvRank = -1;
+        uint8_t* buffer = nullptr;
+        faabric_datatype_t* dataType = nullptr;
+        int count = -1;
+        faabric::MPIMessage::MPIMessageType messageType =
+          faabric::MPIMessage::NORMAL;
+
+        bool isAcknowledged() { return msg != nullptr; }
+
+        void acknowledge(std::shared_ptr<faabric::MPIMessage> msgIn)
+        {
+            msg = msgIn;
+        }
     };
 
     /* Interface to query the buffer size */
@@ -41,31 +50,33 @@ class MpiMessageBuffer
 
     /* Interface to add and delete messages to the buffer */
 
-    void addMessage(Arguments arg);
+    void addMessage(PendingAsyncMpiMessage msg);
 
-    void deleteMessage(const std::list<Arguments>::iterator& argIt);
+    void deleteMessage(
+      const std::list<PendingAsyncMpiMessage>::iterator& msgIt);
 
     /* Interface to get a pointer to a message in the MMB */
 
     // Pointer to a message given its request id
-    std::list<Arguments>::iterator getRequestArguments(int requestId);
+    std::list<PendingAsyncMpiMessage>::iterator getRequestPendingMsg(
+      int requestId);
 
     // Pointer to the first null-pointing (unacknowledged) message
-    std::list<Arguments>::iterator getFirstNullMsg();
+    std::list<PendingAsyncMpiMessage>::iterator getFirstNullMsg();
 
     /* Interface to ask for the number of unacknowledged messages */
 
     // Unacknowledged messages until an iterator (used in await)
     int getTotalUnackedMessagesUntil(
-      const std::list<Arguments>::iterator& argIt);
+      const std::list<PendingAsyncMpiMessage>::iterator& msgIt);
 
     // Unacknowledged messages in the whole buffer (used in recv)
     int getTotalUnackedMessages();
 
   private:
-    std::list<Arguments> args;
+    std::list<PendingAsyncMpiMessage> pendingMsgs;
 
-    std::list<Arguments>::iterator getFirstNullMsgUntil(
-      const std::list<Arguments>::iterator& argIt);
+    std::list<PendingAsyncMpiMessage>::iterator getFirstNullMsgUntil(
+      const std::list<PendingAsyncMpiMessage>::iterator& msgIt);
 };
 }
diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index f3b0d00c1..256d23f97 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -215,8 +215,8 @@ class MpiWorld
     void closeMpiMessageEndpoints();
 
     // Support for asynchronous communications
-    std::shared_ptr<MpiMessageBuffer>
-    getUnackedMessageBuffer(int sendRank, int recvRank);
+    std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
+                                                              int recvRank);
 
     std::shared_ptr<faabric::MPIMessage> recvBatchReturnLast(
      int sendRank, int recvRank, int batchSize = 0);
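The MpiMessageBuffer.h change above replaces the passive Arguments struct with a
PendingAsyncMpiMessage class that carries its own acknowledgement state. A minimal
sketch of the intended lifecycle (the request id and ranks are made up for
illustration, they normally come from irecv):

    // An entry starts unacknowledged: msg defaults to nullptr
    faabric::scheduler::MpiMessageBuffer::PendingAsyncMpiMessage pending;
    pending.requestId = 1337;
    pending.sendRank = 0;
    pending.recvRank = 1;
    assert(!pending.isAcknowledged());

    // When the matching message eventually arrives, acknowledging it
    // just fills in the message pointer
    pending.acknowledge(std::make_shared<faabric::MPIMessage>());
    assert(pending.isAcknowledged());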
diff --git a/src/scheduler/MpiMessageBuffer.cpp b/src/scheduler/MpiMessageBuffer.cpp
index 4ddd672b2..8b2445e40 100644
--- a/src/scheduler/MpiMessageBuffer.cpp
+++ b/src/scheduler/MpiMessageBuffer.cpp
@@ -2,68 +2,72 @@
 #include <faabric/util/logging.h>
 
 namespace faabric::scheduler {
-typedef std::list<MpiMessageBuffer::Arguments>::iterator ArgListIterator;
+typedef std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator
+  MpiMessageIterator;
+
 bool MpiMessageBuffer::isEmpty()
 {
-    return args.empty();
+    return pendingMsgs.empty();
 }
 
 int MpiMessageBuffer::size()
 {
-    return args.size();
+    return pendingMsgs.size();
 }
 
-void MpiMessageBuffer::addMessage(Arguments arg)
+void MpiMessageBuffer::addMessage(PendingAsyncMpiMessage msg)
 {
-    args.push_back(arg);
+    pendingMsgs.push_back(msg);
 }
 
-void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt)
+void MpiMessageBuffer::deleteMessage(const MpiMessageIterator& msgIt)
 {
-    args.erase(argIt);
+    pendingMsgs.erase(msgIt);
 }
 
-ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId)
+MpiMessageIterator MpiMessageBuffer::getRequestPendingMsg(int requestId)
 {
     // The request id must be in the MMB, as an irecv must happen before an
     // await
-    ArgListIterator argIt =
-      std::find_if(args.begin(), args.end(), [requestId](Arguments args) {
-          return args.requestId == requestId;
-      });
+    MpiMessageIterator msgIt =
+      std::find_if(pendingMsgs.begin(),
+                   pendingMsgs.end(),
+                   [requestId](PendingAsyncMpiMessage pendingMsg) {
+                       return pendingMsg.requestId == requestId;
+                   });
 
     // If it's not there, error out
-    if (argIt == args.end()) {
+    if (msgIt == pendingMsgs.end()) {
         SPDLOG_ERROR("Asynchronous request id not in buffer: {}", requestId);
         throw std::runtime_error("Async request not in buffer");
     }
 
-    return argIt;
+    return msgIt;
 }
 
-ArgListIterator MpiMessageBuffer::getFirstNullMsgUntil(
-  const ArgListIterator& argItEnd)
+MpiMessageIterator MpiMessageBuffer::getFirstNullMsgUntil(
+  const MpiMessageIterator& msgItEnd)
 {
-    return std::find_if(args.begin(), argItEnd, [](Arguments args) {
-        return args.msg == nullptr;
-    });
+    return std::find_if(
+      pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage pendingMsg) {
+          return pendingMsg.msg == nullptr;
+      });
 }
 
-ArgListIterator MpiMessageBuffer::getFirstNullMsg()
+MpiMessageIterator MpiMessageBuffer::getFirstNullMsg()
 {
-    return getFirstNullMsgUntil(args.end());
+    return getFirstNullMsgUntil(pendingMsgs.end());
 }
 
 int MpiMessageBuffer::getTotalUnackedMessagesUntil(
-  const ArgListIterator& argItEnd)
+  const MpiMessageIterator& msgItEnd)
 {
-    ArgListIterator firstNull = getFirstNullMsgUntil(argItEnd);
-    return std::distance(firstNull, argItEnd);
+    MpiMessageIterator firstNull = getFirstNullMsgUntil(msgItEnd);
+    return std::distance(firstNull, msgItEnd);
 }
 
 int MpiMessageBuffer::getTotalUnackedMessages()
 {
-    ArgListIterator firstNull = getFirstNullMsg();
-    return std::distance(firstNull, args.end());
+    MpiMessageIterator firstNull = getFirstNullMsg();
+    return std::distance(firstNull, pendingMsgs.end());
 }
 }
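Note that the counting logic relies on an ordering invariant: entries are
acknowledged strictly front to back, so everything before the first null-pointing
entry is acknowledged and a single std::distance call suffices. A minimal sketch
of the resulting behaviour (the request ids are illustrative):

    faabric::scheduler::MpiMessageBuffer mmb;

    MpiMessageBuffer::PendingAsyncMpiMessage acked;
    acked.requestId = 1;
    acked.acknowledge(std::make_shared<faabric::MPIMessage>());
    mmb.addMessage(acked);

    MpiMessageBuffer::PendingAsyncMpiMessage unacked;
    unacked.requestId = 2;
    mmb.addMessage(unacked);

    // Only the trailing null-pointing entry counts as unacknowledged
    assert(mmb.getTotalUnackedMessages() == 1);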
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index 42ea0b8e3..f9716001a 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -207,11 +207,19 @@ void MpiWorld::destroy()
     // Request to rank map should be empty
     if (!reqIdToRanks.empty()) {
         SPDLOG_ERROR(
-          "Destroying the MPI world with {} outstanding async requests",
+          "Destroying the MPI world with {} outstanding irecv requests",
           reqIdToRanks.size());
         throw std::runtime_error("Destroying world with outstanding requests");
     }
 
+    // iSend set should be empty
+    if (!iSendRequests.empty()) {
+        SPDLOG_ERROR(
+          "Destroying the MPI world with {} outstanding isend requests",
+          iSendRequests.size());
+        throw std::runtime_error("Destroying world with outstanding requests");
+    }
+
     // Destroy once per host the shared resources
     if (!isDestroyed.test_and_set()) {
         // Wait (forever) until all ranks are done consuming their queues to
@@ -423,15 +431,19 @@ int MpiWorld::irecv(int sendRank,
     int requestId = (int)faabric::util::generateGid();
     reqIdToRanks.try_emplace(requestId, sendRank, recvRank);
 
-    // Enqueue a request with a null-pointing message (i.e. unacknowledged)
-    // and the generated request id
-    faabric::scheduler::MpiMessageBuffer::Arguments args = {
-        requestId, nullptr,  sendRank, recvRank,
-        buffer,    dataType, count,    messageType
-    };
+    // Enqueue an unacknowledged request (no message)
+    faabric::scheduler::MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg;
+    pendingMsg.requestId = requestId;
+    pendingMsg.sendRank = sendRank;
+    pendingMsg.recvRank = recvRank;
+    pendingMsg.buffer = buffer;
+    pendingMsg.dataType = dataType;
+    pendingMsg.count = count;
+    pendingMsg.messageType = messageType;
+    assert(!pendingMsg.isAcknowledged());
 
     auto umb = getUnackedMessageBuffer(sendRank, recvRank);
-    umb->addMessage(args);
+    umb->addMessage(pendingMsg);
 
     return requestId;
 }
@@ -806,29 +818,29 @@ void MpiWorld::awaitAsyncRequest(int requestId)
     std::shared_ptr<MpiMessageBuffer> umb =
       getUnackedMessageBuffer(sendRank, recvRank);
 
-    std::list<MpiMessageBuffer::Arguments>::iterator argsIndex =
-      umb->getRequestArguments(requestId);
+    std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator msgIt =
+      umb->getRequestPendingMsg(requestId);
 
     std::shared_ptr<faabric::MPIMessage> m;
-    if (argsIndex->msg != nullptr) {
+    if (msgIt->msg != nullptr) {
         // This id has already been acknowledged by a recv call, so do the recv
-        m = argsIndex->msg;
+        m = msgIt->msg;
     } else {
         // We need to acknowledge all messages not acknowledged from the
-        // begining until us
+        // beginning until us
         m = recvBatchReturnLast(
-          sendRank, recvRank, umb->getTotalUnackedMessagesUntil(argsIndex) + 1);
+          sendRank, recvRank, umb->getTotalUnackedMessagesUntil(msgIt) + 1);
     }
 
     doRecv(m,
-           argsIndex->buffer,
-           argsIndex->dataType,
-           argsIndex->count,
+           msgIt->buffer,
+           msgIt->dataType,
+           msgIt->count,
            MPI_STATUS_IGNORE,
-           argsIndex->messageType);
+           msgIt->messageType);
 
     // Remove the acknowledged indexes from the UMB
-    umb->deleteMessage(argsIndex);
+    umb->deleteMessage(msgIt);
 }
 
 void MpiWorld::reduce(int sendRank,
@@ -1213,47 +1225,42 @@ MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize)
     // Recv message: first we receive all messages for which there is an id
     // in the unacknowledged buffer but no msg. Note that these messages
     // (batchSize - 1) were `irecv`-ed before ours.
-    std::shared_ptr<faabric::MPIMessage> m;
-    auto argsIt = umb->getFirstNullMsg();
+    std::shared_ptr<faabric::MPIMessage> ourMsg;
+    auto msgIt = umb->getFirstNullMsg();
     if (isLocal) {
         // First receive messages that happened before us
         for (int i = 0; i < batchSize - 1; i++) {
             SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank);
-            auto _m = getLocalQueue(sendRank, recvRank)->dequeue();
-
-            assert(_m != nullptr);
-            assert(argsIt->msg == nullptr);
+            auto pendingMsg = getLocalQueue(sendRank, recvRank)->dequeue();
 
             // Put the unacked message in the UMB
-            argsIt->msg = _m;
-            argsIt++;
+            assert(!msgIt->isAcknowledged());
+            msgIt->acknowledge(pendingMsg);
+            msgIt++;
         }
 
         // Finally receive the message corresponding to us
        SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank);
-        m = getLocalQueue(sendRank, recvRank)->dequeue();
+        ourMsg = getLocalQueue(sendRank, recvRank)->dequeue();
     } else {
         // First receive messages that happened before us
         for (int i = 0; i < batchSize - 1; i++) {
             SPDLOG_TRACE(
               "MPI - pending remote recv {} -> {}", sendRank, recvRank);
-            auto _m = recvRemoteMpiMessage(sendRank, recvRank);
-
-            assert(_m != nullptr);
-            assert(argsIt->msg == nullptr);
+            auto pendingMsg = recvRemoteMpiMessage(sendRank, recvRank);
 
             // Put the unacked message in the UMB
-            argsIt->msg = _m;
-            argsIt++;
+            assert(!msgIt->isAcknowledged());
+            msgIt->acknowledge(pendingMsg);
+            msgIt++;
         }
 
         // Finally receive the message corresponding to us
         SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank);
-        m = recvRemoteMpiMessage(sendRank, recvRank);
+        ourMsg = recvRemoteMpiMessage(sendRank, recvRank);
     }
 
-    assert(m != nullptr);
-    return m;
+    return ourMsg;
 }
 
 int MpiWorld::getIndexForRanks(int sendRank, int recvRank)
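The await path above depends on per-channel FIFO ordering: awaiting a request
that is not at the front of the unacknowledged message buffer first receives and
acknowledges every earlier pending message, then receives its own. A rough sketch
of the out-of-order case (assuming rank 0 has already sent three matching
messages to rank 1; x1-x3 and the ids are illustrative, BYTES is the cast macro
used in the tests):

    int x1 = -1, x2 = -1, x3 = -1;
    int id1 = world.irecv(0, 1, BYTES(&x1), MPI_INT, 1);
    int id2 = world.irecv(0, 1, BYTES(&x2), MPI_INT, 1);
    int id3 = world.irecv(0, 1, BYTES(&x3), MPI_INT, 1);

    // Two unacknowledged entries precede id3, so this calls
    // recvBatchReturnLast(0, 1, 2 + 1): the first two messages are dequeued
    // and acknowledged in the UMB, the third is copied into x3
    world.awaitAsyncRequest(id3);

    // These now take the already-acknowledged fast path and only remove
    // their entries from the buffer
    world.awaitAsyncRequest(id1);
    world.awaitAsyncRequest(id2);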
diff --git a/tests/test/scheduler/test_mpi_message_buffer.cpp b/tests/test/scheduler/test_mpi_message_buffer.cpp
index 16665771f..40a845c0e 100644
--- a/tests/test/scheduler/test_mpi_message_buffer.cpp
+++ b/tests/test/scheduler/test_mpi_message_buffer.cpp
@@ -6,8 +6,9 @@
 
 using namespace faabric::scheduler;
 
-MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true,
-                                               int overrideRequestId = -1)
+MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments(
+  bool nullMsg = true,
+  int overrideRequestId = -1)
 {
     int requestId;
     if (overrideRequestId != -1) {
@@ -16,16 +17,14 @@ MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true,
         requestId = static_cast<int>(faabric::util::generateGid());
     }
 
-    MpiMessageBuffer::Arguments args = {
-        requestId, nullptr, 0, 1,
-        nullptr,   MPI_INT, 0, faabric::MPIMessage::NORMAL
-    };
+    MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg;
+    pendingMsg.requestId = requestId;
 
     if (!nullMsg) {
-        args.msg = std::make_shared<faabric::MPIMessage>();
+        pendingMsg.msg = std::make_shared<faabric::MPIMessage>();
     }
 
-    return args;
+    return pendingMsg;
 }
 
 namespace tests {
@@ -59,7 +58,7 @@ TEST_CASE("Test getting an iterator from a request id", "[mpi]")
     int requestId = 1337;
     mmb.addMessage(genRandomArguments(true, requestId));
 
-    auto it = mmb.getRequestArguments(requestId);
+    auto it = mmb.getRequestPendingMsg(requestId);
 
     REQUIRE(it->requestId == requestId);
 }
@@ -111,7 +110,7 @@ TEST_CASE("Test getting total unacked messages in message buffer range",
     mmb.addMessage(genRandomArguments(true, requestId));
 
     // Get an iterator to our second null message
-    auto it = mmb.getRequestArguments(requestId);
+    auto it = mmb.getRequestPendingMsg(requestId);
 
     // Check that we have only one unacked message until the iterator
     REQUIRE(mmb.getTotalUnackedMessagesUntil(it) == 1);
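Because PendingAsyncMpiMessage now has in-class default initializers, the test
helper only needs to set the fields a test actually varies. A test mixing
acknowledged and unacknowledged entries then reduces to something like this
sketch (the fixed id 1337 is illustrative):

    MpiMessageBuffer mmb;
    mmb.addMessage(genRandomArguments(false));      // acknowledged: msg is set
    mmb.addMessage(genRandomArguments(true, 1337)); // unacknowledged, fixed id

    REQUIRE(mmb.getTotalUnackedMessages() == 1);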
diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp
index 3d626e10e..a4a3b0391 100644
--- a/tests/test/scheduler/test_mpi_world.cpp
+++ b/tests/test/scheduler/test_mpi_world.cpp
@@ -1202,4 +1202,61 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test all-to-all", "[mpi]")
 
     world.destroy();
 }
+
+TEST_CASE_METHOD(MpiTestFixture,
+                 "Test can't destroy world with outstanding requests",
+                 "[mpi]")
+{
+    int rankA = 0;
+    int rankB = 1;
+    int data = 9;
+    int actual = -1;
+
+    SECTION("Outstanding irecv")
+    {
+        world.send(rankA, rankB, BYTES(&data), MPI_INT, 1);
+        int recvId = world.irecv(rankA, rankB, BYTES(&actual), MPI_INT, 1);
+
+        REQUIRE_THROWS(world.destroy());
+
+        world.awaitAsyncRequest(recvId);
+        REQUIRE(actual == data);
+    }
+
+    SECTION("Outstanding acknowledged irecv")
+    {
+        int data2 = 14;
+        int actual2 = -1;
+
+        world.send(rankA, rankB, BYTES(&data), MPI_INT, 1);
+        world.send(rankA, rankB, BYTES(&data2), MPI_INT, 1);
+        int recvId = world.irecv(rankA, rankB, BYTES(&actual), MPI_INT, 1);
+        int recvId2 = world.irecv(rankA, rankB, BYTES(&actual2), MPI_INT, 1);
+
+        REQUIRE_THROWS(world.destroy());
+
+        // Wait for the second request, which will acknowledge the first one
+        // but not remove it from the pending message buffer
+        world.awaitAsyncRequest(recvId2);
+
+        REQUIRE_THROWS(world.destroy());
+
+        // Wait for the first one
+        world.awaitAsyncRequest(recvId);
+
+        REQUIRE(actual == data);
+        REQUIRE(actual2 == data2);
+    }
+
+    SECTION("Outstanding isend")
+    {
+        int sendId = world.isend(rankA, rankB, BYTES(&data), MPI_INT, 1);
+        world.recv(
+          rankA, rankB, BYTES(&actual), MPI_INT, 1, MPI_STATUS_IGNORE);
+
+        REQUIRE_THROWS(world.destroy());
+
+        world.awaitAsyncRequest(sendId);
+        REQUIRE(actual == data);
+    }
+}
 }
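Taken together, the two destroy-time guards make request hygiene part of the
world's contract: every isend and irecv must be matched by an awaitAsyncRequest
before destroy is called. A minimal well-behaved sequence, mirroring the test
above (rankA, rankB, data and actual as declared there):

    int sendId = world.isend(rankA, rankB, BYTES(&data), MPI_INT, 1);
    int recvId = world.irecv(rankA, rankB, BYTES(&actual), MPI_INT, 1);

    world.awaitAsyncRequest(sendId);
    world.awaitAsyncRequest(recvId);

    // Both the irecv request map and the isend set are now empty,
    // so neither guard in MpiWorld::destroy() throws
    world.destroy();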