Marginal gains in send/recv fast-path #104

Merged: 3 commits, Jun 8, 2021
11 changes: 6 additions & 5 deletions include/faabric/scheduler/MpiWorld.h
@@ -220,12 +220,11 @@ class MpiWorld
std::string function;

std::shared_ptr<state::StateKeyValue> stateKV;
std::unordered_map<int, std::string> rankHostMap;
std::vector<std::string> rankHosts;

std::unordered_map<std::string, uint8_t*> windowPointerMap;

std::unordered_map<std::string, std::shared_ptr<InMemoryMpiQueue>>
localQueueMap;
std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;

std::shared_ptr<faabric::scheduler::MpiAsyncThreadPool> threadPool;
int getMpiThreadPoolSize();
@@ -235,8 +234,10 @@ class MpiWorld
faabric::scheduler::FunctionCallClient& getFunctionCallClient(
const std::string& otherHost);

void checkRankOnThisHost(int rank);

void closeThreadLocalClients();

int getIndexForRanks(int sendRank, int recvRank);

void initLocalQueues();
};
}
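The header change captures the core of this PR: the rank-to-host mapping and the per-rank-pair message queues move from string/hash-map containers onto flat, pre-sized vectors, so the send/recv fast path becomes a constant-time index rather than a string build plus hash lookup. A minimal sketch of the access-pattern difference, using illustrative stand-ins rather than faabric's actual classes:

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Queue
{}; // stand-in for InMemoryMpiQueue

// Old layout: hash map keyed by a "sendRank_recvRank" string
struct MapWorld
{
    std::unordered_map<std::string, std::shared_ptr<Queue>> queues;

    std::shared_ptr<Queue> get(int sendRank, int recvRank)
    {
        // Builds and hashes a temporary string on every message
        std::string key =
          std::to_string(sendRank) + "_" + std::to_string(recvRank);
        return queues.at(key);
    }
};

// New layout: flat vector indexed by sendRank * size + recvRank
struct VectorWorld
{
    int size;
    std::vector<std::shared_ptr<Queue>> queues;

    explicit VectorWorld(int worldSize)
      : size(worldSize)
      , queues(worldSize * worldSize)
    {}

    std::shared_ptr<Queue> get(int sendRank, int recvRank)
    {
        // Constant-time index, no allocation or hashing
        return queues[sendRank * size + recvRank];
    }
};

int main()
{
    VectorWorld world(4);
    world.queues[1 * 4 + 2] = std::make_shared<Queue>();
    return world.get(1, 2) ? 0 : 1;
}

The rest of the diff is the mechanical consequence: callers index instead of looking up keys, and the containers are sized once when the world is created or initialised from a message.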
131 changes: 59 additions & 72 deletions src/scheduler/MpiWorld.cpp
@@ -123,6 +123,9 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
for (const auto& h : hosts) {
faabric::transport::sendMpiHostRankMsg(h, hostRankMsg);
}

// Initialise the memory queues for message reception
initLocalQueues();
}

void MpiWorld::destroy()
@@ -138,10 +141,12 @@ void MpiWorld::destroy()
// clear them.
// Note - this means that an application with outstanding messages, i.e.
// send without recv, will block forever.
for (auto& k : localQueueMap) {
k.second->waitToDrain(-1);
for (auto& q : localQueues) {
if (q != nullptr) {
q->waitToDrain(-1);
}
}
localQueueMap.clear();
localQueues.clear();
}
}

@@ -186,44 +191,46 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)

// Sometimes for testing purposes we may want to initialise a world in the
// _same_ host we have created one (note that this would never happen in
// reality). If so, we skip the rank-host map broadcasting.
// reality). If so, we skip initialising resources already initialised
if (!forceLocal) {
// Block until we receive
faabric::MpiHostsToRanksMessage hostRankMsg =
faabric::transport::recvMpiHostRankMsg();
setAllRankHosts(hostRankMsg);

// Initialise the memory queues for message reception
initLocalQueues();
}
}

std::string MpiWorld::getHostForRank(int rank)
{
if (rankHostMap.find(rank) == rankHostMap.end()) {
logger->error("No known host for rank {}", rank);
throw std::runtime_error("No known host for rank");
assert(rankHosts.size() == size);

if (rank >= size) {
throw std::runtime_error(
fmt::format("Rank bigger than world size ({} > {})", rank, size));
}

{
faabric::util::SharedLock lock(worldMutex);
return rankHostMap[rank];
std::string host = rankHosts[rank];
if (host.empty()) {
throw std::runtime_error(
fmt::format("No host found for rank {}", rank));
}

return host;
}

// Prepare the host-rank map with a vector containing _all_ ranks
// Note - this method should be called by only one rank. This is enforced in
// the world registry
void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg)
{
// Assert we are only setting the values once
assert(rankHosts.size() == 0);

assert(msg.hosts().size() == size);
faabric::util::FullLock lock(worldMutex);
for (int i = 0; i < size; i++) {
auto it = rankHostMap.try_emplace(i, msg.hosts().at(i));
if (!it.second) {
logger->error("Tried to map host ({}) to rank ({}), but rank was "
"already mapped to host ({})",
msg.hosts().at(i),
i + 1,
rankHostMap[i]);
throw std::runtime_error("Rank already mapped to host");
}
}
rankHosts = { msg.hosts().begin(), msg.hosts().end() };
}
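With the vector in place, the removed per-rank try_emplace loop (and its lock and duplicate-mapping error path) collapses into the single range assignment above. A small sketch of the same idiom, using a plain std::vector<std::string> as a hypothetical stand-in for the protobuf repeated field msg.hosts():

#include <cassert>
#include <string>
#include <vector>

int main()
{
    // Stand-in for msg.hosts(): one host string per rank, in rank order
    std::vector<std::string> msgHosts = { "hostA", "hostA", "hostB", "hostB" };

    std::vector<std::string> rankHosts;
    assert(rankHosts.empty()); // the method is only called once per world

    // Iterator-range assignment copies every element in one statement
    rankHosts = { msgHosts.begin(), msgHosts.end() };

    assert(rankHosts.size() == msgHosts.size());
    assert(rankHosts[2] == "hostB");

    return 0;
}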

void MpiWorld::getCartesianRank(int rank,
@@ -413,10 +420,9 @@ void MpiWorld::send(int sendRank,
int count,
faabric::MPIMessage::MPIMessageType messageType)
{
if (recvRank > this->size - 1) {
throw std::runtime_error(fmt::format(
"Rank {} bigger than world size {}", recvRank, this->size));
}
// Work out whether the message is sent locally or to another host
const std::string otherHost = getHostForRank(recvRank);
bool isLocal = otherHost == thisHost;

// Generate a message ID
int msgId = (int)faabric::util::generateGid();
@@ -431,10 +437,6 @@
m->set_count(count);
m->set_messagetype(messageType);

// Work out whether the message is sent locally or to another host
const std::string otherHost = getHostForRank(recvRank);
bool isLocal = otherHost == thisHost;

// Set up message data
if (count > 0 && buffer != nullptr) {
m->set_buffer(buffer, dataType->size * count);
@@ -468,21 +470,10 @@ void MpiWorld::recv(int sendRank,
std::shared_ptr<faabric::MPIMessage> m =
getLocalQueue(sendRank, recvRank)->dequeue();

if (messageType != m->messagetype()) {
logger->error(
"Message types mismatched on {}->{} (expected={}, got={})",
sendRank,
recvRank,
messageType,
m->messagetype());
throw std::runtime_error("Mismatched message types");
}

if (m->count() > count) {
logger->error(
"Message too long for buffer (msg={}, buffer={})", m->count(), count);
throw std::runtime_error("Message too long");
}
// Assert message integrity
// Note - these checks won't happen in Release builds
assert(m->messagetype() == messageType);
assert(m->count() <= count);

// TODO - avoid copy here
// Copy message data
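Both receive-side checks are downgraded from logged exceptions to assert, which compiles away when NDEBUG is defined (the usual Release configuration), so the hot path only pays for the validation in Debug builds. This is also why the "type mismatch" test case is deleted from the test file further down: the condition it exercised no longer throws. A minimal illustration of the trade-off, not taken from faabric:

#include <cassert>
#include <stdexcept>

// Debug-only check: the assert vanishes entirely under -DNDEBUG
void checkWithAssert(int expectedType, int actualType)
{
    assert(actualType == expectedType);
}

// Always-on check: comparison and exception machinery in every build
void checkWithThrow(int expectedType, int actualType)
{
    if (actualType != expectedType) {
        throw std::runtime_error("Mismatched message types");
    }
}

int main()
{
    checkWithAssert(0, 0);
    checkWithThrow(0, 0);
    return 0;
}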
@@ -1119,23 +1110,34 @@ void MpiWorld::enqueueMessage(faabric::MPIMessage& msg)
std::shared_ptr<InMemoryMpiQueue> MpiWorld::getLocalQueue(int sendRank,
int recvRank)
{
checkRankOnThisHost(recvRank);
assert(getHostForRank(recvRank) == thisHost);
assert(localQueues.size() == size * size);

std::string key = std::to_string(sendRank) + "_" + std::to_string(recvRank);
if (localQueueMap.find(key) == localQueueMap.end()) {
faabric::util::FullLock lock(worldMutex);
return localQueues[getIndexForRanks(sendRank, recvRank)];
}

if (localQueueMap.find(key) == localQueueMap.end()) {
auto mq = new InMemoryMpiQueue();
localQueueMap.emplace(
std::pair<std::string, InMemoryMpiQueue*>(key, mq));
// We pre-allocate all _potentially_ necessary queues in advance. Queues are
// necessary to _receive_ messages, thus we initialise all queues whose
// corresponding receiver is local to this host
// Note - the queues themselves perform concurrency control
void MpiWorld::initLocalQueues()
{
// Assert we only allocate queues once
assert(localQueues.size() == 0);
localQueues.resize(size * size);
for (int recvRank = 0; recvRank < size; recvRank++) {
if (getHostForRank(recvRank) == thisHost) {
for (int sendRank = 0; sendRank < size; sendRank++) {
localQueues[getIndexForRanks(sendRank, recvRank)] =
std::make_shared<InMemoryMpiQueue>();
}
}
}
}
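Queues are only needed on the receiving side, so initLocalQueues leaves the slots whose receiver lives on another host as empty shared_ptrs; the size * size vector is deliberately sparse when the world spans several hosts. A rough sketch of the allocation pattern, with a hypothetical isLocalRank predicate standing in for the getHostForRank(recvRank) == thisHost comparison:

#include <memory>
#include <vector>

struct Queue
{}; // stand-in for InMemoryMpiQueue

int main()
{
    int size = 4;
    // Hypothetical placement: ranks 0 and 1 on this host, 2 and 3 elsewhere
    auto isLocalRank = [](int rank) { return rank < 2; };

    std::vector<std::shared_ptr<Queue>> localQueues(size * size);
    for (int recvRank = 0; recvRank < size; recvRank++) {
        if (isLocalRank(recvRank)) {
            for (int sendRank = 0; sendRank < size; sendRank++) {
                localQueues[sendRank * size + recvRank] =
                  std::make_shared<Queue>();
            }
        }
    }

    // 8 of the 16 slots are populated here; the null ones are skipped when
    // draining in destroy()
    return 0;
}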

{
faabric::util::SharedLock lock(worldMutex);
return localQueueMap[key];
}
int MpiWorld::getIndexForRanks(int sendRank, int recvRank)
{
return sendRank * size + recvRank;
}
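getIndexForRanks flattens the (sendRank, recvRank) pair row-major by sender, so every pair gets a unique slot in [0, size * size). For example, with size = 4 the queue for messages from rank 2 to rank 3 sits at index 2 * 4 + 3 = 11. A quick sanity check of the indexing, illustrative only:

#include <cassert>

int getIndexForRanks(int sendRank, int recvRank, int size)
{
    return sendRank * size + recvRank;
}

int main()
{
    int size = 4;
    assert(getIndexForRanks(0, 0, size) == 0);
    assert(getIndexForRanks(2, 3, size) == 11);
    assert(getIndexForRanks(size - 1, size - 1, size) == size * size - 1);
    return 0;
}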

void MpiWorld::rmaGet(int sendRank,
@@ -1227,21 +1229,6 @@ long MpiWorld::getLocalQueueSize(int sendRank, int recvRank)
return queue->size();
}

void MpiWorld::checkRankOnThisHost(int rank)
Collaborator (PR author) commented:
Even if we decide to keep the if statements, we can delete this method and keep the check in getLocalQueue through getHostForRank.

{
// Check if we know about this rank on this host
if (rankHostMap.count(rank) == 0) {
logger->error("No mapping found for rank {} on this host", rank);
throw std::runtime_error("No mapping found for rank");
} else if (rankHostMap[rank] != thisHost) {
logger->error("Trying to access rank {} on {} but it's on {}",
rank,
thisHost,
rankHostMap[rank]);
throw std::runtime_error("Accessing in-memory queue for remote rank");
}
}

void MpiWorld::createWindow(const int winRank,
const int winSize,
uint8_t* windowPtr)
58 changes: 0 additions & 58 deletions tests/test/scheduler/test_mpi_world.cpp
@@ -301,19 +301,6 @@ TEST_CASE("Test send and recv on same host", "[mpi]")
REQUIRE(status.bytesSize == messageData.size() * sizeof(int));
}

SECTION("Test recv with type missmatch")
{
// Receive a message of a different type
auto buffer = new int[messageData.size()];
REQUIRE_THROWS(world.recv(rankA1,
rankA2,
BYTES(buffer),
MPI_INT,
messageData.size(),
nullptr,
faabric::MPIMessage::SENDRECV));
}

tearDown({ &world });
}

@@ -665,51 +652,6 @@ TEST_CASE("Test probe", "[mpi]")
tearDown({ &world });
}

TEST_CASE("Test can't get in-memory queue for non-local ranks", "[mpi]")
{
cleanFaabric();

std::string otherHost = LOCALHOST;

auto& sch = faabric::scheduler::getScheduler();

// Force the scheduler to initialise a world in the remote host by setting
// a worldSize bigger than the slots available locally
int worldSize = 4;
faabric::HostResources localResources;
localResources.set_slots(2);
localResources.set_usedslots(1);
faabric::HostResources otherResources;
otherResources.set_slots(2);

// Set up a remote host
sch.addHostToGlobalSet(otherHost);

// Mock everything to make sure the other host has resources as well
faabric::util::setMockMode(true);
sch.setThisHostResources(localResources);
faabric::scheduler::queueResourceResponse(otherHost, otherResources);

faabric::Message msg = faabric::util::messageFactory(user, func);
msg.set_mpiworldsize(worldSize);
scheduler::MpiWorld worldA;
worldA.create(msg, worldId, worldSize);

scheduler::MpiWorld worldB;
worldB.overrideHost(otherHost);
worldB.initialiseFromMsg(msg);

// Check that we can't access rank on another host locally
REQUIRE_THROWS(worldA.getLocalQueue(0, 2));

// Double check even when we've retrieved the rank
REQUIRE(worldA.getHostForRank(2) == otherHost);
REQUIRE_THROWS(worldA.getLocalQueue(0, 2));

faabric::util::setMockMode(false);
tearDown({ &worldA, &worldB });
}

TEST_CASE("Check sending to invalid rank", "[mpi]")
{
cleanFaabric();