diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index 7e9a4c522..b56fae84b 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -224,15 +224,16 @@ class MpiWorld
     /* MPI internal messaging layer */

     // Track at which host each rank lives
-    std::vector<std::string> rankHosts;
+    std::vector<std::string> hostForRank;
     int getIndexForRanks(int sendRank, int recvRank);
-    std::vector<int> getRanksForHost(const std::string& host);

-    // Track ranks that are local to this world, and local/remote leaders
+    // Store the ranks that live in each host
+    std::map<std::string, std::vector<int>> ranksForHost;
+
+    // Track local and remote leaders. The leader is stored in the first
+    // position of the host to ranks map.
     // MPITOPTP - this information exists in the broker
     int localLeader = -1;
-    std::vector<int> localRanks;
-    std::vector<int> remoteLeaders;
     void initLocalRemoteLeaders();

     // In-memory queues for local messaging
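To make the new bookkeeping concrete, here is an illustrative sketch (not part of the patch) assuming a four-rank world split across two hypothetical hosts:

    // Rank-to-host mapping as recorded at world creation
    std::vector<std::string> hostForRank = { "hostA", "hostA",
                                             "hostB", "hostB" };

    // After initLocalRemoteLeaders() runs, ranks are grouped per host with
    // each host's local leader (its lowest rank) in the first position:
    //   ranksForHost["hostA"] == { 0, 1 }   // rank 0 leads hostA
    //   ranksForHost["hostB"] == { 2, 3 }   // rank 2 leads hostB
    // From hostA's perspective localLeader == 0, and hostB's leader is
    // reachable as ranksForHost["hostB"].front().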
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index d61dca062..57d019200 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -242,7 +242,7 @@ void MpiWorld::create(faabric::Message& call, int newId, int newSize)
     executedAt.insert(executedAt.begin(), thisHost);

     // Record rank-to-host mapping and base ports
-    rankHosts = executedAt;
+    hostForRank = executedAt;
     basePorts = initLocalBasePorts(executedAt);

     // Record which ranks are local to this world, and query for all leaders
@@ -258,7 +258,7 @@ void MpiWorld::create(faabric::Message& call, int newId, int newSize)
 void MpiWorld::broadcastHostsToRanks()
 {
     // Set up a list of hosts to broadcast to (excluding this host)
-    std::set<std::string> targetHosts(rankHosts.begin(), rankHosts.end());
+    std::set<std::string> targetHosts(hostForRank.begin(), hostForRank.end());
     targetHosts.erase(thisHost);

     if (targetHosts.empty()) {
@@ -268,7 +268,7 @@ void MpiWorld::broadcastHostsToRanks()

     // Register hosts to rank mappings on this host
     faabric::MpiHostsToRanksMessage hostRankMsg;
-    *hostRankMsg.mutable_hosts() = { rankHosts.begin(), rankHosts.end() };
+    *hostRankMsg.mutable_hosts() = { hostForRank.begin(), hostForRank.end() };

     // Prepare the base port for each rank
     *hostRankMsg.mutable_baseports() = { basePorts.begin(), basePorts.end() };
@@ -344,13 +344,13 @@ void MpiWorld::initialiseFromMsg(faabric::Message& msg)
     // enforced in the world registry.

     // Assert we are only setting the values once
-    assert(rankHosts.empty());
+    assert(hostForRank.empty());
     assert(basePorts.empty());

     assert(hostRankMsg.hosts().size() == size);
     assert(hostRankMsg.baseports().size() == size);

-    rankHosts = { hostRankMsg.hosts().begin(), hostRankMsg.hosts().end() };
+    hostForRank = { hostRankMsg.hosts().begin(), hostRankMsg.hosts().end() };
     basePorts = { hostRankMsg.baseports().begin(),
                   hostRankMsg.baseports().end() };
@@ -368,9 +368,9 @@ void MpiWorld::setMsgForRank(faabric::Message& msg)

 std::string MpiWorld::getHostForRank(int rank)
 {
-    assert(rankHosts.size() == size);
+    assert(hostForRank.size() == size);

-    std::string host = rankHosts[rank];
+    std::string host = hostForRank[rank];
     if (host.empty()) {
         throw std::runtime_error(
           fmt::format("No host found for rank {}", rank));
@@ -379,37 +379,29 @@ std::string MpiWorld::getHostForRank(int rank)
     return host;
 }

-std::vector<int> MpiWorld::getRanksForHost(const std::string& host)
-{
-    assert(rankHosts.size() == size);
-
-    std::vector<int> ranksForHost;
-    for (int i = 0; i < rankHosts.size(); i++) {
-        if (rankHosts.at(i) == host) {
-            ranksForHost.push_back(i);
-        }
-    }
-
-    return ranksForHost;
-}
-
 // The local leader for an MPI world is defined as the lowest rank assigned to
-// this host
+// this host. For simplicity, we set the local leader to be the first element
+// in the ranks to hosts map.
 void MpiWorld::initLocalRemoteLeaders()
 {
-    std::set<std::string> uniqueHosts(rankHosts.begin(), rankHosts.end());
+    // First, group the ranks per host they belong to for convenience
+    assert(hostForRank.size() == size);

-    for (const std::string& host : uniqueHosts) {
-        auto ranksInHost = getRanksForHost(host);
-        // Persist the ranks that are colocated in this host for further use
-        if (host == thisHost) {
-            localRanks = ranksInHost;
-            localLeader =
-              *std::min_element(ranksInHost.begin(), ranksInHost.end());
-        } else {
-            remoteLeaders.push_back(
-              *std::min_element(ranksInHost.begin(), ranksInHost.end()));
+    for (int rank = 0; rank < hostForRank.size(); rank++) {
+        std::string host = hostForRank.at(rank);
+        ranksForHost[host].push_back(rank);
+    }
+
+    // Second, put the local leader for each host (currently the lowest rank)
+    // at the front. Note we must iterate by reference for the swap to
+    // persist in the map.
+    for (auto& it : ranksForHost) {
+        // Persist the local leader in this host for further use
+        if (it.first == thisHost) {
+            localLeader =
+              *std::min_element(it.second.begin(), it.second.end());
         }
+
+        std::iter_swap(it.second.begin(),
+                       std::min_element(it.second.begin(), it.second.end()));
     }
 }
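The collectives below rely on the leader-at-front invariant, so it is worth seeing the grouping logic in isolation. A self-contained sketch (hypothetical free function, not part of the patch):

    #include <algorithm>
    #include <map>
    #include <string>
    #include <vector>

    std::map<std::string, std::vector<int>> groupRanksByHost(
      const std::vector<std::string>& hostForRank)
    {
        std::map<std::string, std::vector<int>> ranksForHost;
        for (int rank = 0; rank < (int)hostForRank.size(); rank++) {
            ranksForHost[hostForRank.at(rank)].push_back(rank);
        }

        // Move each host's lowest rank (its local leader) to the front.
        // Ranks are appended in ascending order, so the swap is currently a
        // no-op, but it keeps the invariant explicit.
        for (auto& [host, ranks] : ranksForHost) {
            std::iter_swap(ranks.begin(),
                           std::min_element(ranks.begin(), ranks.end()));
        }

        return ranksForHost;
    }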
@@ -788,19 +780,32 @@ void MpiWorld::broadcast(int sendRank,
     SPDLOG_TRACE("MPI - bcast {} -> all", sendRank);

     if (recvRank == sendRank) {
-        // The sending rank sends a message to all local ranks in the
-        // broadcast, and all remote leaders
-        for (const int localRecvRank : localRanks) {
-            if (localRecvRank == recvRank) {
-                continue;
-            }
-
-            send(recvRank, localRecvRank, buffer, dataType, count, messageType);
-        }
-
-        for (const int remoteRecvRank : remoteLeaders) {
-            send(
-              recvRank, remoteRecvRank, buffer, dataType, count, messageType);
+        for (const auto& it : ranksForHost) {
+            if (it.first == thisHost) {
+                // Send message to all our local ranks besides ourselves
+                for (const int localRecvRank : it.second) {
+                    if (localRecvRank == recvRank) {
+                        continue;
+                    }
+
+                    send(recvRank,
+                         localRecvRank,
+                         buffer,
+                         dataType,
+                         count,
+                         messageType);
+                }
+            } else {
+                // Send message to the local leader of each remote host. Note
+                // that the local leader will then broadcast the message to
+                // its local ranks.
+                send(recvRank,
+                     it.second.front(),
+                     buffer,
+                     dataType,
+                     count,
+                     messageType);
+            }
         }
     } else if (recvRank == localLeader) {
         // If we are the local leader, first we receive the message sent by
@@ -810,7 +815,7 @@ void MpiWorld::broadcast(int sendRank,
         // If the broadcast originated locally, we are done. If not, we now
         // distribute to all our local ranks
         if (getHostForRank(sendRank) != thisHost) {
-            for (const int localRecvRank : localRanks) {
+            for (const int localRecvRank : ranksForHost[thisHost]) {
                 if (localRecvRank == recvRank) {
                     continue;
                 }
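The resulting message flow is easiest to see as a trace. With the same hypothetical layout as before (ranks 0-1 on hostA, ranks 2-3 on hostB), a broadcast rooted at rank 0 behaves as follows:

    // rank 0 (root, hostA):  send(0, 1, ...)  local delivery
    //                        send(0, 2, ...)  one message to hostB's leader
    // rank 2 (hostB leader): recv(...), then, because the broadcast
    //                        originated remotely, send(2, 3, ...)
    // ranks 1 and 3:         plain recv(...)
    //
    // Cross-host traffic is one message per remote host, regardless of how
    // many ranks each remote host runs.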
@@ -906,64 +911,156 @@ void MpiWorld::gather(int sendRank,
                      int recvCount)
 {
     checkSendRecvMatch(sendType, sendCount, recvType, recvCount);

-    size_t sendOffset = sendCount * sendType->size;
-    size_t recvOffset = recvCount * recvType->size;
+    size_t sendSize = sendCount * sendType->size;
+    size_t recvSize = recvCount * recvType->size;
+
+    // This method does a two-step gather where each local leader does a
+    // gather for its local ranks, and then the receiver and the local
+    // leaders do one global gather. There are five scenarios:
+    // 1. The rank calling gather is the receiver of the gather. This rank
+    // expects all its local ranks and the remote local leaders to send
+    // their data for gathering.
+    // 2. The rank calling gather is a local leader, not co-located with the
+    // gather receiver. This rank expects all its local ranks to send their
+    // data for gathering, and then sends the resulting aggregation to the
+    // gather receiver.
+    // 3. The rank calling gather is a local leader, co-located with the
+    // gather receiver. This rank just sends its data for gathering to the
+    // gather receiver.
+    // 4. The rank calling gather is not a local leader, not co-located with
+    // the gather receiver. This rank sends its data for gathering to its
+    // local leader.
+    // 5. The rank calling gather is not a local leader, co-located with the
+    // gather receiver. This rank sends its data for gathering to the gather
+    // receiver.
+
+    bool isGatherReceiver = sendRank == recvRank;
+    bool isLocalLeader = sendRank == localLeader;
+    bool isLocalGather = getHostForRank(recvRank) == thisHost;
+
+    // Additionally, when sending data for gathering we must differentiate
+    // between two scenarios.
+    // 1. The sending rank sets the MPI_IN_PLACE flag. This means the gather
+    // is part of an allGather, and the sending rank has allocated enough
+    // space for all ranks in the sending buffer. As a consequence, the
+    // to-be-gathered data is at the offset corresponding to the sending
+    // rank.
+    // 2. The sending rank does not set the MPI_IN_PLACE flag. This means
+    // that the sending buffer only contains the to-be-gathered data.
     bool isInPlace = sendBuffer == recvBuffer;
+    size_t sendBufferOffset = isInPlace ? sendRank * sendSize : 0;

-    // If we're the root, do the gathering
-    if (sendRank == recvRank) {
+    if (isGatherReceiver) {
+        // Scenario 1
         SPDLOG_TRACE("MPI - gather all -> {}", recvRank);

-        // Iterate through each rank
-        for (int r = 0; r < size; r++) {
-            // Work out where in the receive buffer this rank's data goes
-            uint8_t* recvChunk = recvBuffer + (r * recvOffset);
-
-            if ((r == recvRank) && isInPlace) {
-                // If operating in-place, data for the root rank is already
-                // in position
-                continue;
-            } else if (r == recvRank) {
-                // Copy data locally on root
-                std::copy(sendBuffer, sendBuffer + sendOffset, recvChunk);
-            } else {
-                // Receive data from rank if it's not the root
-                recv(r,
-                     recvRank,
-                     recvChunk,
-                     recvType,
-                     recvCount,
-                     nullptr,
-                     faabric::MPIMessage::GATHER);
+        for (const auto& it : ranksForHost) {
+            if (it.first == thisHost) {
+                // Receive from all local ranks besides ourselves
+                for (const int r : it.second) {
+                    // If receiving from ourselves, but not in place, copy
+                    // our data to the right offset
+                    if (r == recvRank && !isInPlace) {
+                        ::memcpy(recvBuffer + (recvRank * recvSize),
+                                 sendBuffer,
+                                 sendSize);
+                    } else if (r != recvRank) {
+                        recv(r,
+                             recvRank,
+                             recvBuffer + (r * recvSize),
+                             recvType,
+                             recvCount,
+                             nullptr,
+                             faabric::MPIMessage::GATHER);
+                    }
+                }
+            } else {
+                // Receive from remote local leaders their local gathered
+                // data
+                auto rankData =
+                  std::make_unique<uint8_t[]>(it.second.size() * recvSize);
+
+                recv(it.second.front(),
+                     recvRank,
+                     rankData.get(),
+                     recvType,
+                     recvCount * it.second.size(),
+                     nullptr,
+                     faabric::MPIMessage::GATHER);
+
+                // Copy each received chunk to its offset
+                for (int r = 0; r < it.second.size(); r++) {
+                    ::memcpy(recvBuffer + (it.second.at(r) * recvSize),
+                             rankData.get() + (r * recvSize),
+                             recvSize);
+                }
             }
         }
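Mapping the five scenarios onto the two-host, four-rank layout exercised by the tests later in this diff (ranks 0-1 local, ranks 2-3 remote) is a useful cross-check; this is illustration only:

    // With recvRank = 0 (receiver is also hostA's leader):
    //   rank 0 -> scenario 1: receives rank 1's chunk, plus one aggregated
    //             message from rank 2 carrying ranks 2-3 (2 * sendCount)
    //   rank 1 -> scenario 5: sends its chunk straight to rank 0
    //   rank 2 -> scenario 2: gathers rank 3 locally, then sends one
    //             two-chunk message to rank 0
    //   rank 3 -> scenario 4: sends its chunk to its local leader, rank 2
    //
    // With recvRank = 1, rank 0 hits scenario 3: a local leader co-located
    // with the receiver, so it sends only its own chunk to rank 1. The
    // trailing else is therefore unreachable and purely defensive.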
-    } else {
-        if (isInPlace) {
-            // A non-root rank running gather "in place" happens as part of
-            // an allgather operation. In this case, the send and receive
-            // buffer are the same, and the rank is eventually expecting a
-            // broadcast of the gather result into this buffer. This means
-            // that this buffer is big enough for the whole gather result,
-            // with this rank's data already in place. Therefore we need to
-            // send _only_ the part of the send buffer relating to this
-            // rank.
-            const uint8_t* sendChunk = sendBuffer + (sendRank * sendOffset);
-            send(sendRank,
-                 recvRank,
-                 sendChunk,
-                 sendType,
-                 sendCount,
-                 faabric::MPIMessage::GATHER);
-        } else {
-            // Normal sending
-            send(sendRank,
-                 recvRank,
-                 sendBuffer,
-                 sendType,
-                 sendCount,
-                 faabric::MPIMessage::GATHER);
-        }
+    } else if (isLocalLeader && !isLocalGather) {
+        // Scenario 2
+        auto rankData = std::make_unique<uint8_t[]>(
+          ranksForHost[thisHost].size() * sendSize);
+
+        // Gather all our local ranks' data and send in a single remote
+        // message
+        for (int r = 0; r < ranksForHost[thisHost].size(); r++) {
+            if (ranksForHost[thisHost].at(r) == sendRank) {
+                // Receive from ourselves, just copy from/to the right offset
+                ::memcpy(rankData.get() + r * sendSize,
+                         sendBuffer + sendBufferOffset,
+                         sendSize);
+            } else {
+                // Receive from other local ranks
+                recv(ranksForHost[thisHost].at(r),
+                     sendRank,
+                     rankData.get() + r * sendSize,
+                     sendType,
+                     sendCount,
+                     nullptr,
+                     faabric::MPIMessage::GATHER);
+            }
+        }
+
+        // Send the locally-gathered data to the receiver rank
+        send(sendRank,
+             recvRank,
+             rankData.get(),
+             sendType,
+             sendCount * ranksForHost[thisHost].size(),
+             faabric::MPIMessage::GATHER);
+    } else if (isLocalLeader && isLocalGather) {
+        // Scenario 3
+        send(sendRank,
+             recvRank,
+             sendBuffer + sendBufferOffset,
+             sendType,
+             sendCount,
+             faabric::MPIMessage::GATHER);
+    } else if (!isLocalLeader && !isLocalGather) {
+        // Scenario 4
+        send(sendRank,
+             localLeader,
+             sendBuffer + sendBufferOffset,
+             sendType,
+             sendCount,
+             faabric::MPIMessage::GATHER);
+    } else if (!isLocalLeader && isLocalGather) {
+        // Scenario 5
+        send(sendRank,
+             recvRank,
+             sendBuffer + sendBufferOffset,
+             sendType,
+             sendCount,
+             faabric::MPIMessage::GATHER);
+    } else {
+        SPDLOG_ERROR("Don't know how to gather rank's data.");
+        SPDLOG_ERROR("- sendRank: {}\n- recvRank: {}\n- isGatherReceiver: "
+                     "{}\n- isLocalLeader: {}\n- isLocalGather: {}",
+                     sendRank,
+                     recvRank,
+                     isGatherReceiver,
+                     isLocalLeader,
+                     isLocalGather);
+        throw std::runtime_error("Don't know how to gather rank's data.");
     }
 }
@@ -1065,18 +1162,8 @@ void MpiWorld::reduce(int sendRank,
     auto rankData = std::make_unique<uint8_t[]>(bufferSize);

     if (sendRank == recvRank) {
-        // If we're the receiver of the reduce, await inputs from our local
-        // ranks (besides ourselves) and remote leaders
         SPDLOG_TRACE("MPI - reduce ({}) all -> {}", operation->id, recvRank);

-        // Work out the list of all the ranks we need to wait for
-        std::vector<int> senderRanks = localRanks;
-        senderRanks.erase(
-          std::remove(senderRanks.begin(), senderRanks.end(), sendRank),
-          senderRanks.end());
-        senderRanks.insert(
-          senderRanks.end(), remoteLeaders.begin(), remoteLeaders.end());
-
         // If not receiving in-place, initialize the receive buffer to the
         // send buffer values. This prevents issues when 0-initializing for
         // operators like the minimum, or product.
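The reduce now follows the same two-step shape. As an illustration (again the hypothetical 2x2 layout, reducing to rank 0 with a sum):

    // rank 3: send(3, 2, ...)            its local leader on hostB
    // rank 2: reduces rank 3's data into a copy of its own send buffer,
    //         then send(2, 0, ...)       one cross-host message
    // rank 1: send(1, 0, ...)            local delivery on hostA
    // rank 0: op_reduce()s rank 1's chunk and hostB's partial result into
    //         recvBuffer
    //
    // Regrouping the reduction this way assumes the operator is associative
    // and commutative, as the built-in MPI operators are (up to
    // floating-point rounding).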
@@ -1084,21 +1171,43 @@ void MpiWorld::reduce(int sendRank,
         // already done and the results are written in the recv buffer
         bool isInPlace = sendBuffer == recvBuffer;
         if (!isInPlace) {
-            memcpy(recvBuffer, sendBuffer, bufferSize);
+            ::memcpy(recvBuffer, sendBuffer, bufferSize);
         }

-        for (const int r : senderRanks) {
-            // Work out the data for this rank
-            memset(rankData.get(), 0, bufferSize);
-            recv(r,
-                 recvRank,
-                 rankData.get(),
-                 datatype,
-                 count,
-                 nullptr,
-                 faabric::MPIMessage::REDUCE);
+        for (const auto& it : ranksForHost) {
+            if (it.first == thisHost) {
+                // Reduce all data from our local ranks besides ourselves
+                for (const int r : it.second) {
+                    if (r == recvRank) {
+                        continue;
+                    }
+
+                    memset(rankData.get(), 0, bufferSize);
+                    recv(r,
+                         recvRank,
+                         rankData.get(),
+                         datatype,
+                         count,
+                         nullptr,
+                         faabric::MPIMessage::REDUCE);
+
+                    op_reduce(
+                      operation, datatype, count, rankData.get(), recvBuffer);
+                }
+            } else {
+                // For remote ranks, only receive from the host leader
+                memset(rankData.get(), 0, bufferSize);
+                recv(it.second.front(),
+                     recvRank,
+                     rankData.get(),
+                     datatype,
+                     count,
+                     nullptr,
+                     faabric::MPIMessage::REDUCE);

-            op_reduce(operation, datatype, count, rankData.get(), recvBuffer);
+                op_reduce(
+                  operation, datatype, count, rankData.get(), recvBuffer);
+            }
         }

     } else if (sendRank == localLeader) {
@@ -1110,9 +1219,9 @@ void MpiWorld::reduce(int sendRank,
         // that we do so in a copy of the send buffer, as the application
         // does not expect said buffer's contents to be modified.
         auto sendBufferCopy = std::make_unique<uint8_t[]>(bufferSize);
-        memcpy(sendBufferCopy.get(), sendBuffer, bufferSize);
+        ::memcpy(sendBufferCopy.get(), sendBuffer, bufferSize);

-        for (const int r : localRanks) {
+        for (const int r : ranksForHost[thisHost]) {
             if (r == sendRank) {
                 continue;
             }
@@ -1302,7 +1411,7 @@ void MpiWorld::scan(int rank,
     // need also to be considered.
     size_t bufferSize = datatype->size * count;

     if (!isInPlace) {
-        memcpy(recvBuffer, sendBuffer, bufferSize);
+        ::memcpy(recvBuffer, sendBuffer, bufferSize);
     }

     if (rank > 0) {
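For context, the rest of scan() sits outside this hunk; assuming the linear chain that the `if (rank > 0)` guard suggests, the memcpy above seeds each rank's prefix with its own contribution:

    // Presumed chain (sketch, not shown in this diff): rank 0 forwards its
    // value to rank 1; each rank r > 0 op_reduce()s the prefix received
    // from rank r - 1 into recvBuffer, then forwards it to rank r + 1, if
    // any.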
diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index a7870874d..9318c901e 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -626,11 +626,26 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
     std::vector<int> actual(thisWorldSize * nPerRank, -1);

     // Call gather for each rank other than the root (out of order)
-    int root = thisHostRankA;
+    int root;
+    std::vector<int> orderedLocalGatherRanks;
+
+    SECTION("Gather receiver is also local leader")
+    {
+        root = 0;
+        orderedLocalGatherRanks = { 1, 2, 0 };
+    }
+
+    SECTION("Gather receiver is not a local leader")
+    {
+        root = 1;
+        orderedLocalGatherRanks = { 1, 2, 0 };
+    }
+
     std::thread otherWorldThread([this, root, &rankData, nPerRank] {
         otherWorld.initialiseFromMsg(msg);

-        for (int rank : otherWorldRanks) {
+        std::vector<int> orderedRemoteGatherRanks = { 4, 5, 3 };
+        for (const int rank : orderedRemoteGatherRanks) {
             otherWorld.gather(rank,
                               root,
                               BYTES(rankData[rank].data()),
@@ -645,7 +660,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
         otherWorld.destroy();
     });

-    for (int rank : thisWorldRanks) {
+    for (const int rank : orderedLocalGatherRanks) {
         if (rank == root) {
             continue;
         }
@@ -1282,4 +1297,124 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     otherWorld.destroy();
     thisWorld.destroy();
 }
+
+std::set<int> getMsgCountsFromMessages(
+  std::vector<std::shared_ptr<faabric::MPIMessage>> msgs)
+{
+    std::set<int> counts;
+    for (const auto& msg : msgs) {
+        counts.insert(msg->count());
+    }
+
+    return counts;
+}
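The assertions in the new test below also use getReceiversFromMessages, which is defined earlier in this file rather than in this hunk. It presumably mirrors the new helper, along these lines (sketch, assuming MPIMessage's destination field):

    std::set<int> getReceiversFromMessages(
      std::vector<std::shared_ptr<faabric::MPIMessage>> msgs)
    {
        std::set<int> receivers;
        for (const auto& msg : msgs) {
            receivers.insert(msg->destination());
        }

        return receivers;
    }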
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test number of messages sent during gather",
+                 "[mpi]")
+{
+    int worldSize = 4;
+    setWorldSizes(worldSize, 2, 2);
+    std::vector<int> messageData = { 0, 1, 2 };
+    int nPerRank = messageData.size();
+
+    // Init worlds
+    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(true);
+    thisWorld.broadcastHostsToRanks();
+    REQUIRE(getMpiHostsToRanksMessages().size() == 1);
+    otherWorld.initialiseFromMsg(msg);
+
+    std::set<int> expectedSentMsgRanks;
+    std::set<int> expectedSentMsgCounts;
+    int expectedNumMsgSent;
+    int sendRank;
+    int recvRank;
+
+    SECTION("Call gather from receiver (local), and receiver is local leader")
+    {
+        recvRank = 0;
+        sendRank = recvRank;
+        expectedNumMsgSent = 0;
+        expectedSentMsgRanks = {};
+        expectedSentMsgCounts = {};
+    }
+
+    SECTION(
+      "Call gather from receiver (local), and receiver is non-local leader")
+    {
+        recvRank = 1;
+        sendRank = recvRank;
+        expectedNumMsgSent = 0;
+        expectedSentMsgRanks = {};
+        expectedSentMsgCounts = {};
+    }
+
+    SECTION("Call gather from non-receiver, colocated with receiver, and "
+            "local leader")
+    {
+        recvRank = 1;
+        sendRank = 0;
+        expectedNumMsgSent = 1;
+        expectedSentMsgRanks = { recvRank };
+        expectedSentMsgCounts = { nPerRank };
+    }
+
+    SECTION("Call gather from non-receiver, colocated with receiver")
+    {
+        recvRank = 0;
+        sendRank = 1;
+        expectedNumMsgSent = 1;
+        expectedSentMsgRanks = { recvRank };
+        expectedSentMsgCounts = { nPerRank };
+    }
+
+    SECTION("Call gather from non-receiver rank, not colocated with "
+            "receiver, but local leader")
+    {
+        recvRank = 0;
+        sendRank = 2;
+        expectedNumMsgSent = 1;
+        expectedSentMsgRanks = { recvRank };
+        expectedSentMsgCounts = { 2 * nPerRank };
+    }
+
+    SECTION("Call gather from non-receiver rank, not colocated with receiver")
+    {
+        recvRank = 0;
+        sendRank = 3;
+        expectedNumMsgSent = 1;
+        expectedSentMsgRanks = { 2 };
+        expectedSentMsgCounts = { nPerRank };
+    }
+
+    std::vector<int> gatherData(worldSize * nPerRank);
+    if (sendRank < 2) {
+        thisWorld.gather(sendRank,
+                         recvRank,
+                         BYTES(messageData.data()),
+                         MPI_INT,
+                         nPerRank,
+                         BYTES(gatherData.data()),
+                         MPI_INT,
+                         nPerRank);
+    } else {
+        otherWorld.gather(sendRank,
+                          recvRank,
+                          BYTES(messageData.data()),
+                          MPI_INT,
+                          nPerRank,
+                          BYTES(gatherData.data()),
+                          MPI_INT,
+                          nPerRank);
+    }
+
+    auto msgs = getMpiMockedMessages(sendRank);
+    REQUIRE(msgs.size() == expectedNumMsgSent);
+    REQUIRE(getReceiversFromMessages(msgs) == expectedSentMsgRanks);
+    REQUIRE(getMsgCountsFromMessages(msgs) == expectedSentMsgCounts);
+
+    faabric::util::setMockMode(false);
+    otherWorld.destroy();
+    thisWorld.destroy();
+}
 }