Skip to content

Commit

Permalink
amend reduce algorithm to mimic broadcast and reduce the number of c…
Browse files Browse the repository at this point in the history
…ross-host messages
  • Loading branch information
csegarragonz committed Dec 1, 2021
1 parent cdd122a commit f9e0e86
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 16 deletions.
4 changes: 2 additions & 2 deletions include/faabric/scheduler/MpiWorld.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ class MpiWorld
faabric_datatype_t* recvType,
int recvCount);

void reduce(int sendRank,
int recvRank,
void reduce(int thisRank,
int rootRank,
uint8_t* sendBuffer,
uint8_t* recvBuffer,
faabric_datatype_t* datatype,
Expand Down
67 changes: 53 additions & 14 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1070,38 +1070,65 @@ void MpiWorld::awaitAsyncRequest(int requestId)
umb->deleteMessage(msgIt);
}

void MpiWorld::reduce(int sendRank,
int recvRank,
void MpiWorld::reduce(int thisRank,
int rootRank,
uint8_t* sendBuffer,
uint8_t* recvBuffer,
faabric_datatype_t* datatype,
int count,
faabric_op_t* operation)
{
// If we're the receiver, await inputs
if (sendRank == recvRank) {
SPDLOG_TRACE("MPI - reduce ({}) all -> {}", operation->id, recvRank);
size_t bufferSize = datatype->size * count;
auto rankData = std::make_unique<uint8_t[]>(bufferSize);

size_t bufferSize = datatype->size * count;
if (thisRank == rootRank) {
// If we're the root of the reduce, await inputs from our local ranks
// (besides ourselves) and remote masters
SPDLOG_TRACE("MPI - reduce ({}) all -> {}", operation->id, rootRank);

bool isInPlace = sendBuffer == recvBuffer;
// Work out the list of all the ranks we need to wait for: every local rank
// except ourselves, followed by the remote masters. std::remove only shifts
// the kept elements forward and returns the new logical end, so it has to be
// paired with erase to actually drop this rank from the vector; otherwise the
// vector keeps its size and we would wait on a stale trailing entry.
std::vector<int> rootRecvRanks = localRanks;
rootRecvRanks.erase(
  std::remove(rootRecvRanks.begin(), rootRecvRanks.end(), thisRank),
  rootRecvRanks.end());
rootRecvRanks.insert(
  rootRecvRanks.end(), remoteMasters.begin(), remoteMasters.end());

// If not receiving in-place, initialize the receive buffer to the send
// buffer values. This prevents issues when 0-initializing for operators
// like the minimum, or product.
// If we're receiving from ourselves and in-place, our work is
// already done and the results are written in the recv buffer
bool isInPlace = sendBuffer == recvBuffer;
if (!isInPlace) {
memcpy(recvBuffer, sendBuffer, bufferSize);
}

auto rankData = std::make_unique<uint8_t[]>(bufferSize);
for (int r = 0; r < size; r++) {
for (const int r : rootRecvRanks) {
// Work out the data for this rank
memset(rankData.get(), 0, bufferSize);
if (r != recvRank) {
recv(r,
thisRank,
rankData.get(),
datatype,
count,
nullptr,
faabric::MPIMessage::REDUCE);

op_reduce(operation, datatype, count, rankData.get(), recvBuffer);
}

} else if (thisRank == localMaster) {
// If we are the local master (but not the root of the reduce) and the
// root rank is not local to us, do a reduce with the data of all our
// local ranks, and then send the result to the root
if (getHostForRank(rootRank) != thisHost) {
for (const int r : localRanks) {
if (r == thisRank) {
continue;
}

memset(rankData.get(), 0, bufferSize);
recv(r,
recvRank,
thisRank,
rankData.get(),
datatype,
count,
Expand All @@ -1113,10 +1140,22 @@ void MpiWorld::reduce(int sendRank,
}
}

// Send to the root rank
send(thisRank,
rootRank,
sendBuffer,
datatype,
count,
faabric::MPIMessage::REDUCE);
} else {
// Do the sending
send(sendRank,
recvRank,
// If we are a non-root and non-local-master rank, we send our data
// for reduction either to our local master or the root, depending on
// whether we are colocated with the root rank or not
int realRecvRank =
getHostForRank(rootRank) == thisHost ? rootRank : localMaster;

send(thisRank,
realRecvRank,
sendBuffer,
datatype,
count,
Expand Down
97 changes: 97 additions & 0 deletions tests/test/scheduler/test_remote_mpi_worlds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,4 +1023,101 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
otherWorld.destroy();
thisWorld.destroy();
}

// Checks, for each kind of calling rank, how many messages a single call to
// reduce sends and which ranks they are addressed to. Each SECTION picks one
// (sendRank, rootRank) combination; the shared code below the sections runs
// the reduce in mock mode and inspects the mocked messages.
TEST_CASE_METHOD(RemoteMpiTestFixture,
                 "Test number of messages sent during reduce",
                 "[mpi]")
{
    // Register a world of four ranks. The sections below drive ranks 0 and 1
    // through thisWorld and ranks 2 and 3 through otherWorld (presumably two
    // ranks per host - confirm against setWorldSizes).
    setWorldSizes(4, 2, 2);

    // Init worlds
    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
    faabric::util::setMockMode(true);
    thisWorld.broadcastHostsToRanks();
    REQUIRE(getMpiHostsToRanksMessages().size() == 1);
    otherWorld.initialiseFromMsg(msg);

    // Expectations for the reduce call, filled in by exactly one section.
    // They are given defined defaults so that no path reads an indeterminate
    // value, and the message count is unsigned to match the size() it is
    // compared against.
    std::set<int> expectedRecvRanks;
    size_t expectedNumMsg = 0;
    int sendRank = 0;
    int rootRank = 0;

    SECTION("Check from root rank (local), and root is local master")
    {
        rootRank = 0;
        sendRank = rootRank;
        expectedNumMsg = 0;
        expectedRecvRanks = {};
    }

    SECTION("Check from root rank (local), and root is non-local master")
    {
        rootRank = 1;
        sendRank = rootRank;
        expectedNumMsg = 0;
        expectedRecvRanks = {};
    }

    SECTION("Check from local non-root rank, and non-root is local master")
    {
        rootRank = 1;
        sendRank = 0;
        expectedNumMsg = 1;
        expectedRecvRanks = { rootRank };
    }

    SECTION("Check from local non-root rank, and non-root is non-local-master")
    {
        rootRank = 0;
        sendRank = 1;
        expectedNumMsg = 1;
        expectedRecvRanks = { rootRank };
    }

    SECTION("Check from remote rank, and remote rank is local master")
    {
        rootRank = 0;
        sendRank = 2;
        expectedNumMsg = 1;
        expectedRecvRanks = { rootRank };
    }

    SECTION("Check from remote rank, and remote rank is not local master")
    {
        rootRank = 0;
        sendRank = 3;
        expectedNumMsg = 1;
        // The message is expected to go to rank 2 rather than straight to
        // the root
        expectedRecvRanks = { 2 };
    }

    // Call reduce as the sending rank, through the world that rank is driven
    // from in this test (ranks below 2 use thisWorld, the rest otherWorld)
    std::vector<int> messageData = { 0, 1, 2 };
    std::vector<int> recvData(messageData.size());
    if (sendRank < 2) {
        thisWorld.reduce(sendRank,
                         rootRank,
                         BYTES(messageData.data()),
                         BYTES(recvData.data()),
                         MPI_INT,
                         messageData.size(),
                         MPI_SUM);
    } else {
        otherWorld.reduce(sendRank,
                          rootRank,
                          BYTES(messageData.data()),
                          BYTES(recvData.data()),
                          MPI_INT,
                          messageData.size(),
                          MPI_SUM);
    }

    // Check the number of messages the sending rank produced and who they
    // were addressed to
    auto msgs = getMpiMockedMessages(sendRank);
    REQUIRE(msgs.size() == expectedNumMsg);
    REQUIRE(getReceiversFromMessages(msgs) == expectedRecvRanks);

    faabric::util::setMockMode(false);
    otherWorld.destroy();
    thisWorld.destroy();
}
}

0 comments on commit f9e0e86

Please sign in to comment.