diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index 39a788f4e..9150ff287 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -220,12 +220,11 @@ class MpiWorld
     std::string function;
 
     std::shared_ptr<state::StateKeyValue> stateKV;
-    std::unordered_map<int, std::string> rankHostMap;
+    std::vector<std::string> rankHosts;
 
     std::unordered_map<std::string, uint8_t*> windowPointerMap;
 
-    std::unordered_map<std::string, std::shared_ptr<InMemoryMpiQueue>>
-      localQueueMap;
+    std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
 
     std::shared_ptr<MpiAsyncThreadPool> threadPool;
     int getMpiThreadPoolSize();
@@ -235,8 +234,10 @@ class MpiWorld
     faabric::scheduler::FunctionCallClient& getFunctionCallClient(
       const std::string& otherHost);
 
-    void checkRankOnThisHost(int rank);
-
     void closeThreadLocalClients();
+
+    int getIndexForRanks(int sendRank, int recvRank);
+
+    void initLocalQueues();
 };
 }
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index 3054490c8..8a438eb2c 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -123,6 +123,9 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
     for (const auto& h : hosts) {
         faabric::transport::sendMpiHostRankMsg(h, hostRankMsg);
     }
+
+    // Initialise the memory queues for message reception
+    initLocalQueues();
 }
 
 void MpiWorld::destroy()
@@ -138,10 +141,12 @@ void MpiWorld::destroy()
         // clear them.
         // Note - this means that an application with outstanding messages, i.e.
         // send without recv, will block forever.
-        for (auto& k : localQueueMap) {
-            k.second->waitToDrain(-1);
+        for (auto& q : localQueues) {
+            if (q != nullptr) {
+                q->waitToDrain(-1);
+            }
         }
-        localQueueMap.clear();
+        localQueues.clear();
     }
 }
 
@@ -186,44 +191,46 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
 
     // Sometimes for testing purposes we may want to initialise a world in the
     // _same_ host we have created one (note that this would never happen in
-    // reality). If so, we skip the rank-host map broadcasting.
+    // reality). If so, we skip initialising resources already initialised
     if (!forceLocal) {
         // Block until we receive
         faabric::MpiHostsToRanksMessage hostRankMsg =
           faabric::transport::recvMpiHostRankMsg();
         setAllRankHosts(hostRankMsg);
+
+        // Initialise the memory queues for message reception
+        initLocalQueues();
     }
 }
 
 std::string MpiWorld::getHostForRank(int rank)
 {
-    if (rankHostMap.find(rank) == rankHostMap.end()) {
-        logger->error("No known host for rank {}", rank);
-        throw std::runtime_error("No known host for rank");
+    assert(rankHosts.size() == size);
+
+    if (rank >= size) {
+        throw std::runtime_error(
+          fmt::format("Rank bigger than world size ({} > {})", rank, size));
     }
 
-    {
-        faabric::util::SharedLock lock(worldMutex);
-        return rankHostMap[rank];
+    std::string host = rankHosts[rank];
+    if (host.empty()) {
+        throw std::runtime_error(
+          fmt::format("No host found for rank {}", rank));
     }
+
+    return host;
 }
 
 // Prepare the host-rank map with a vector containing _all_ ranks
+// Note - this method should be called by only one rank. This is enforced in
+// the world registry
 void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg)
 {
+    // Assert we are only setting the values once
+    assert(rankHosts.size() == 0);
+    assert(msg.hosts().size() == size);
 
-    faabric::util::FullLock lock(worldMutex);
-    for (int i = 0; i < size; i++) {
-        auto it = rankHostMap.try_emplace(i, msg.hosts().at(i));
-        if (!it.second) {
-            logger->error("Tried to map host ({}) to rank ({}), but rank was "
-                          "already mapped to host ({})",
-                          msg.hosts().at(i),
-                          i + 1,
-                          rankHostMap[i]);
-            throw std::runtime_error("Rank already mapped to host");
-        }
-    }
+    rankHosts = { msg.hosts().begin(), msg.hosts().end() };
 }
 
 void MpiWorld::getCartesianRank(int rank,
@@ -413,10 +420,9 @@ void MpiWorld::send(int sendRank,
                     int count,
                     faabric::MPIMessage::MPIMessageType messageType)
 {
-    if (recvRank > this->size - 1) {
-        throw std::runtime_error(fmt::format(
-          "Rank {} bigger than world size {}", recvRank, this->size));
-    }
+    // Work out whether the message is sent locally or to another host
+    const std::string otherHost = getHostForRank(recvRank);
+    bool isLocal = otherHost == thisHost;
 
     // Generate a message ID
     int msgId = (int)faabric::util::generateGid();
@@ -431,10 +437,6 @@ void MpiWorld::send(int sendRank,
     m->set_count(count);
     m->set_messagetype(messageType);
 
-    // Work out whether the message is sent locally or to another host
-    const std::string otherHost = getHostForRank(recvRank);
-    bool isLocal = otherHost == thisHost;
-
     // Set up message data
     if (count > 0 && buffer != nullptr) {
         m->set_buffer(buffer, dataType->size * count);
     }
@@ -468,21 +470,10 @@ void MpiWorld::recv(int sendRank,
     std::shared_ptr<faabric::MPIMessage> m =
       getLocalQueue(sendRank, recvRank)->dequeue();
 
-    if (messageType != m->messagetype()) {
-        logger->error(
-          "Message types mismatched on {}->{} (expected={}, got={})",
-          sendRank,
-          recvRank,
-          messageType,
-          m->messagetype());
-        throw std::runtime_error("Mismatched message types");
-    }
-
-    if (m->count() > count) {
-        logger->error(
-          "Message too long for buffer (msg={}, buffer={})", m->count(), count);
-        throw std::runtime_error("Message too long");
-    }
+    // Assert message integrity
+    // Note - this checks won't happen in Release builds
+    assert(m->messagetype() == messageType);
+    assert(m->count() <= count);
 
     // TODO - avoid copy here
     // Copy message data
@@ -1119,23 +1110,34 @@ void MpiWorld::enqueueMessage(faabric::MPIMessage& msg)
 std::shared_ptr<InMemoryMpiQueue> MpiWorld::getLocalQueue(int sendRank,
                                                           int recvRank)
 {
-    checkRankOnThisHost(recvRank);
+    assert(getHostForRank(recvRank) == thisHost);
+    assert(localQueues.size() == size * size);
 
-    std::string key = std::to_string(sendRank) + "_" + std::to_string(recvRank);
-    if (localQueueMap.find(key) == localQueueMap.end()) {
-        faabric::util::FullLock lock(worldMutex);
+    return localQueues[getIndexForRanks(sendRank, recvRank)];
+}
 
-        if (localQueueMap.find(key) == localQueueMap.end()) {
-            auto mq = new InMemoryMpiQueue();
-            localQueueMap.emplace(
-              std::pair<std::string, InMemoryMpiQueue*>(key, mq));
+// We pre-allocate all _potentially_ necessary queues in advance. Queues are
+// necessary to _receive_ messages, thus we initialise all queues whose
+// corresponding receiver is local to this host
+// Note - the queues themselves perform concurrency control
+void MpiWorld::initLocalQueues()
+{
+    // Assert we only allocate queues once
+    assert(localQueues.size() == 0);
+    localQueues.resize(size * size);
+    for (int recvRank = 0; recvRank < size; recvRank++) {
+        if (getHostForRank(recvRank) == thisHost) {
+            for (int sendRank = 0; sendRank < size; sendRank++) {
+                localQueues[getIndexForRanks(sendRank, recvRank)] =
+                  std::make_shared<InMemoryMpiQueue>();
+            }
         }
     }
+}
 
-    {
-        faabric::util::SharedLock lock(worldMutex);
-        return localQueueMap[key];
-    }
+int MpiWorld::getIndexForRanks(int sendRank, int recvRank)
+{
+    return sendRank * size + recvRank;
 }
 
 void MpiWorld::rmaGet(int sendRank,
@@ -1227,21 +1229,6 @@ long MpiWorld::getLocalQueueSize(int sendRank, int recvRank)
     return queue->size();
 }
 
-void MpiWorld::checkRankOnThisHost(int rank)
-{
-    // Check if we know about this rank on this host
-    if (rankHostMap.count(rank) == 0) {
-        logger->error("No mapping found for rank {} on this host", rank);
-        throw std::runtime_error("No mapping found for rank");
-    } else if (rankHostMap[rank] != thisHost) {
-        logger->error("Trying to access rank {} on {} but it's on {}",
-                      rank,
-                      thisHost,
-                      rankHostMap[rank]);
-        throw std::runtime_error("Accessing in-memory queue for remote rank");
-    }
-}
-
 void MpiWorld::createWindow(const int winRank,
                             const int winSize,
                             uint8_t* windowPtr)
diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp
index 4bf4b3fa8..1076bd7aa 100644
--- a/tests/test/scheduler/test_mpi_world.cpp
+++ b/tests/test/scheduler/test_mpi_world.cpp
@@ -301,19 +301,6 @@ TEST_CASE("Test send and recv on same host", "[mpi]")
         REQUIRE(status.bytesSize == messageData.size() * sizeof(int));
     }
 
-    SECTION("Test recv with type missmatch")
-    {
-        // Receive a message from a different type
-        auto buffer = new int[messageData.size()];
-        REQUIRE_THROWS(world.recv(rankA1,
-                                  rankA2,
-                                  BYTES(buffer),
-                                  MPI_INT,
-                                  messageData.size(),
-                                  nullptr,
-                                  faabric::MPIMessage::SENDRECV));
-    }
-
     tearDown({ &world });
 }
 
@@ -665,51 +652,6 @@ TEST_CASE("Test probe", "[mpi]")
     tearDown({ &world });
 }
 
-TEST_CASE("Test can't get in-memory queue for non-local ranks", "[mpi]")
-{
-    cleanFaabric();
-
-    std::string otherHost = LOCALHOST;
-
-    auto& sch = faabric::scheduler::getScheduler();
-
-    // Force the scheduler to initialise a world in the remote host by setting
-    // a worldSize bigger than the slots available locally
-    int worldSize = 4;
-    faabric::HostResources localResources;
-    localResources.set_slots(2);
-    localResources.set_usedslots(1);
-    faabric::HostResources otherResources;
-    otherResources.set_slots(2);
-
-    // Set up a remote host
-    sch.addHostToGlobalSet(otherHost);
-
-    // Mock everything to make sure the other host has resources as well
-    faabric::util::setMockMode(true);
-    sch.setThisHostResources(localResources);
-    faabric::scheduler::queueResourceResponse(otherHost, otherResources);
-
-    faabric::Message msg = faabric::util::messageFactory(user, func);
-    msg.set_mpiworldsize(worldSize);
-    scheduler::MpiWorld worldA;
-    worldA.create(msg, worldId, worldSize);
-
-    scheduler::MpiWorld worldB;
-    worldB.overrideHost(otherHost);
-    worldB.initialiseFromMsg(msg);
-
-    // Check that we can't access rank on another host locally
-    REQUIRE_THROWS(worldA.getLocalQueue(0, 2));
-
-    // Double check even when we've retrieved the rank
-    REQUIRE(worldA.getHostForRank(2) == otherHost);
-    REQUIRE_THROWS(worldA.getLocalQueue(0, 2));
-
-    faabric::util::setMockMode(false);
-    tearDown({ &worldA, &worldB });
-}
-
 TEST_CASE("Check sending to invalid rank", "[mpi]")
 {
     cleanFaabric();
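
Reviewer note, illustration only (not part of the patch): the change replaces the string-keyed localQueueMap with a flat vector of size * size slots, where the queue for a (sendRank, recvRank) pair lives at index sendRank * size + recvRank and only slots whose receiver is local to this host are ever allocated. The self-contained sketch below reproduces that indexing; DummyQueue, worldSize and localRecvRank are placeholders invented for the example, not faabric identifiers.

// Illustration only - mimics the flat (sendRank, recvRank) indexing used by
// the new localQueues vector, with a trivial stand-in for InMemoryMpiQueue.
#include <cassert>
#include <cstdio>
#include <memory>
#include <vector>

struct DummyQueue
{};

constexpr int worldSize = 3;     // assumed world size for the example
constexpr int localRecvRank = 1; // pretend only rank 1 lives on this host

// Same arithmetic as the new MpiWorld::getIndexForRanks
int getIndexForRanks(int sendRank, int recvRank)
{
    return sendRank * worldSize + recvRank;
}

int main()
{
    // Pre-allocate one slot per (sender, receiver) pair, as initLocalQueues does
    std::vector<std::shared_ptr<DummyQueue>> queues(worldSize * worldSize);

    // Only pairs whose receiver is local get a real queue; slots for remote
    // receivers stay null and cost just a pointer
    for (int sendRank = 0; sendRank < worldSize; sendRank++) {
        queues[getIndexForRanks(sendRank, localRecvRank)] =
          std::make_shared<DummyQueue>();
    }

    assert(queues[getIndexForRanks(2, localRecvRank)] != nullptr);
    assert(queues[getIndexForRanks(0, 2)] == nullptr);

    printf("index for (send=2, recv=1): %d\n",
           getIndexForRanks(2, localRecvRank));
    return 0;
}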