State-less MPI rank management (#103)
* removing all state-related rank processing, use messages instead. this could potentially allow us to remove the barrier from MPI_init

* updating tests to remove all references to registerRank

* move send and recv mpi to transport

* formatting + cleanup

* updating original message to avoid bug in faasm + cleanup

* uncomment tests

* pr comments
csegarragonz authored Jun 7, 2021
1 parent 2be5d88 commit c9716ba
Showing 15 changed files with 315 additions and 301 deletions.
2 changes: 1 addition & 1 deletion include/faabric/scheduler/MpiContext.h
@@ -9,7 +9,7 @@ class MpiContext
public:
MpiContext();

void createWorld(const faabric::Message& msg);
int createWorld(const faabric::Message& msg);

void joinWorld(const faabric::Message& msg);

20 changes: 4 additions & 16 deletions include/faabric/scheduler/MpiWorld.h
@@ -15,28 +15,22 @@ namespace faabric::scheduler {
typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
InMemoryMpiQueue;

struct MpiWorldState
{
int worldSize;
};

std::string getWorldStateKey(int worldId);

std::string getRankStateKey(int worldId, int rankId);

class MpiWorld
{
public:
MpiWorld();

void create(const faabric::Message& call, int newId, int newSize);

void initialiseFromState(const faabric::Message& msg, int worldId);

void registerRank(int rank);
void initialiseFromMsg(const faabric::Message& msg,
bool forceLocal = false);

std::string getHostForRank(int rank);

void setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg);

std::string getUser();

std::string getFunction();
@@ -238,17 +232,11 @@ class MpiWorld

std::vector<int> cartProcsPerDim;

void setUpStateKV();

std::shared_ptr<state::StateKeyValue> getRankHostState(int rank);

faabric::scheduler::FunctionCallClient& getFunctionCallClient(
const std::string& otherHost);

void checkRankOnThisHost(int rank);

void pushToState();

void closeThreadLocalClients();
};
}
3 changes: 1 addition & 2 deletions include/faabric/scheduler/MpiWorldRegistry.h
@@ -12,8 +12,7 @@ class MpiWorldRegistry
int worldId,
std::string hostOverride = "");

scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg,
int worldId);
scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg);

scheduler::MpiWorld& getWorld(int worldId);

13 changes: 13 additions & 0 deletions include/faabric/transport/MpiMessageEndpoint.h
@@ -0,0 +1,13 @@
#pragma once

#include <faabric/proto/faabric.pb.h>
#include <faabric/transport/MessageEndpoint.h>
#include <faabric/transport/common.h>
#include <faabric/transport/macros.h>

namespace faabric::transport {
faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();

void sendMpiHostRankMsg(const std::string& hostIn,
const faabric::MpiHostsToRanksMessage msg);
}
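
The pair above forms a one-shot handshake: rank 0's host pushes the serialised rank-to-host mapping, and each remote host blocks until it arrives. A minimal usage sketch, assuming only these two declarations and the protobuf-generated accessors (the IPs are illustrative):

#include <cassert>
#include <faabric/transport/MpiMessageEndpoint.h>

// Master side: build the mapping (index == rank) and send it to a
// remote host
faabric::MpiHostsToRanksMessage hostRankMsg;
hostRankMsg.add_hosts("10.0.0.1"); // rank 0
hostRankMsg.add_hosts("10.0.0.2"); // rank 1
faabric::transport::sendMpiHostRankMsg("10.0.0.2", hostRankMsg);

// Remote side: block until the mapping arrives
faabric::MpiHostsToRanksMessage recvd =
  faabric::transport::recvMpiHostRankMsg();
assert(recvd.hosts().at(1) == "10.0.0.2");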
2 changes: 2 additions & 0 deletions include/faabric/transport/common.h
@@ -8,3 +8,5 @@
#define MPI_MESSAGE_PORT 8005
#define SNAPSHOT_PORT 8006
#define REPLY_PORT_OFFSET 100

#define MPI_PORT 8800
3 changes: 1 addition & 2 deletions src/mpi_native/mpi_native.cpp
@@ -28,10 +28,9 @@ faabric::Message* getExecutingCall()

faabric::scheduler::MpiWorld& getExecutingWorld()
{
int worldId = executingContext.getWorldId();
faabric::scheduler::MpiWorldRegistry& reg =
faabric::scheduler::getMpiWorldRegistry();
return reg.getOrInitialiseWorld(*getExecutingCall(), worldId);
return reg.getOrInitialiseWorld(*getExecutingCall());
}

static void notImplemented(const std::string& funcName)
7 changes: 7 additions & 0 deletions src/proto/faabric.proto
@@ -84,6 +84,13 @@ message MPIMessage {
bytes buffer = 8;
}

// Instead of sending a map or a list of ranks, we use the index in the
// repeated string field as the rank. Note that protobuf guarantees that
// the order of a repeated field is preserved.
message MpiHostsToRanksMessage {
repeated string hosts = 1;
}

message Message {
int32 id = 1;
int32 appId = 2;
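To illustrate the comment above: the rank is recovered purely from a host's position in the repeated field, so no explicit rank numbers travel on the wire. A small round-trip sketch with hypothetical host names:

#include <cassert>
#include <string>
#include <vector>
#include <faabric/proto/faabric.pb.h>

// Ranks 1 and 2 are co-located: the vector index encodes the rank
std::vector<std::string> hosts = { "hostA", "hostB", "hostB" };

faabric::MpiHostsToRanksMessage msg;
*msg.mutable_hosts() = { hosts.begin(), hosts.end() };

// Repeated fields keep their order across (de)serialisation
std::string wire = msg.SerializeAsString();
faabric::MpiHostsToRanksMessage parsed;
parsed.ParseFromString(wire);
assert(parsed.hosts().at(2) == "hostB"); // rank 2 -> hostB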
8 changes: 5 additions & 3 deletions src/scheduler/MpiContext.cpp
@@ -12,7 +12,7 @@ MpiContext::MpiContext()
, worldId(-1)
{}

void MpiContext::createWorld(const faabric::Message& msg)
int MpiContext::createWorld(const faabric::Message& msg)
{
const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();

@@ -32,6 +32,9 @@ void MpiContext::createWorld(const faabric::Message& msg)
// Set up this context
isMpi = true;
rank = 0;

// Return the world ID so the caller can store it in the original message
return worldId;
}

void MpiContext::joinWorld(const faabric::Message& msg)
@@ -47,8 +50,7 @@

// Register with the world
MpiWorldRegistry& registry = getMpiWorldRegistry();
MpiWorld& world = registry.getOrInitialiseWorld(msg, worldId);
world.registerRank(rank);
registry.getOrInitialiseWorld(msg);
}

bool MpiContext::getIsMpi()
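With the signature change above, the generated world ID travels on the message itself rather than through shared state. A hypothetical call site (the real caller lives in the executor, outside this diff):

#include <faabric/scheduler/MpiContext.h>

faabric::Message msg;
msg.set_user("mpi");
msg.set_function("hellompi");

// Rank 0: create the world and stamp the new ID back onto the original
// message, which is all later ranks need in order to join
faabric::scheduler::MpiContext ctx;
msg.set_mpiworldid(ctx.createWorld(msg));

// Ranks > 0 arrive with mpiworldid/mpirank already set by
// MpiWorld::create(), so they simply call ctx.joinWorld(msg)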
152 changes: 53 additions & 99 deletions src/scheduler/MpiWorld.cpp
@@ -4,6 +4,7 @@
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/state/State.h>
#include <faabric/transport/MpiMessageEndpoint.h>
#include <faabric/util/environment.h>
#include <faabric/util/func.h>
#include <faabric/util/gids.h>
@@ -35,39 +36,12 @@ std::string getWorldStateKey(int worldId)
return "mpi_world_" + std::to_string(worldId);
}

std::string getRankStateKey(int worldId, int rankId)
{
if (worldId <= 0 || rankId < 0) {
throw std::runtime_error(
fmt::format("World ID must be >0 and rank ID must be >=0 ({}, {})",
worldId,
rankId));
}
return "mpi_rank_" + std::to_string(worldId) + "_" + std::to_string(rankId);
}

std::string getWindowStateKey(int worldId, int rank, size_t size)
{
return "mpi_win_" + std::to_string(worldId) + "_" + std::to_string(rank) +
"_" + std::to_string(size);
}

void MpiWorld::setUpStateKV()
{
if (stateKV == nullptr) {
state::State& state = state::getGlobalState();
std::string stateKey = getWorldStateKey(id);
stateKV = state.getKV(user, stateKey, sizeof(MpiWorldState));
}
}

std::shared_ptr<state::StateKeyValue> MpiWorld::getRankHostState(int rank)
{
state::State& state = state::getGlobalState();
std::string stateKey = getRankStateKey(id, rank);
return state.getKV(user, stateKey, MPI_HOST_STATE_LEN);
}

faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient(
const std::string& otherHost)
{
@@ -114,13 +88,6 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
threadPool = std::make_shared<faabric::scheduler::MpiAsyncThreadPool>(
getMpiThreadPoolSize());

// Write this to state
setUpStateKV();
pushToState();

// Register this as the master
registerRank(0);

auto& sch = faabric::scheduler::getScheduler();

// Dispatch all the chained calls
@@ -133,9 +100,29 @@
msg.set_ismpi(true);
msg.set_mpiworldid(id);
msg.set_mpirank(i + 1);
msg.set_mpiworldsize(size);
}

sch.callFunctions(req);
// Send the init messages (note that message i corresponds to rank i+1)
std::vector<std::string> executedAt = sch.callFunctions(req);
assert(executedAt.size() == size - 1);

// Prepend this host for rank 0
executedAt.insert(executedAt.begin(), thisHost);

// Register the host-to-rank mappings on this host
faabric::MpiHostsToRanksMessage hostRankMsg;
*hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() };
setAllRankHosts(hostRankMsg);

// Set up the set of hosts to broadcast to (excluding this host)
std::set<std::string> hosts(executedAt.begin(), executedAt.end());
hosts.erase(thisHost);

// Do the broadcast
for (const auto& h : hosts) {
faabric::transport::sendMpiHostRankMsg(h, hostRankMsg);
}
}
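
To make the fan-out concrete: in a three-rank world where ranks 1 and 2 both land on a second host, one message per remote host is enough, because the registry initialises at most one world per host and the broadcast is consumed during that single initialisation. A worked sketch with hypothetical host names:

#include <set>
#include <string>
#include <vector>

// callFunctions() reported where ranks 1 and 2 were dispatched
std::vector<std::string> executedAt = { "hostB", "hostB" };

// Prepend this host for rank 0: hosts are now [hostA, hostB, hostB]
std::string thisHost = "hostA";
executedAt.insert(executedAt.begin(), thisHost);

// Deduplicate and drop this host: only hostB receives the broadcast
std::set<std::string> hosts(executedAt.begin(), executedAt.end());
hosts.erase(thisHost); // hosts == { "hostB" }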

void MpiWorld::destroy()
@@ -147,12 +134,6 @@
// Note - we are deliberately not deleting the KV in the global state
// TODO - find a way to do this only from the master client

for (auto& s : rankHostMap) {
const std::shared_ptr<state::StateKeyValue>& rankState =
getRankHostState(s.first);
state::getGlobalState().deleteKV(rankState->user, rankState->key);
}

// Wait (forever) until all ranks are done consuming their queues to
// clear them.
// Note - this means that an application with outstanding messages, i.e.
@@ -193,77 +174,32 @@
functionCallClients.clear();
}

void MpiWorld::initialiseFromState(const faabric::Message& msg, int worldId)
void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
{
id = worldId;
id = msg.mpiworldid();
user = msg.user();
function = msg.function();
size = msg.mpiworldsize();

setUpStateKV();

// Read from state
MpiWorldState s{};
stateKV->pull();
stateKV->get(BYTES(&s));
size = s.worldSize;
threadPool = std::make_shared<faabric::scheduler::MpiAsyncThreadPool>(
getMpiThreadPoolSize());
}

void MpiWorld::pushToState()
{
// Write to state
MpiWorldState s{
.worldSize = this->size,
};

stateKV->set(BYTES(&s));
stateKV->pushFull();
}

void MpiWorld::registerRank(int rank)
{
{
faabric::util::FullLock lock(worldMutex);
rankHostMap[rank] = thisHost;
// Sometimes, for testing purposes, we may want to initialise a world on
// the _same_ host where we created it (note that this would never happen
// in a real deployment). If so, we skip the rank-host map broadcast.
if (!forceLocal) {
// Block until we receive the rank-to-host mapping from the master host
faabric::MpiHostsToRanksMessage hostRankMsg =
faabric::transport::recvMpiHostRankMsg();
setAllRankHosts(hostRankMsg);
}

// Note that the host name may be shorter than the buffer, so we need to pad
// with nulls
uint8_t hostBytesBuffer[MPI_HOST_STATE_LEN];
memset(hostBytesBuffer, '\0', MPI_HOST_STATE_LEN);
::strcpy((char*)hostBytesBuffer, thisHost.c_str());

const std::shared_ptr<state::StateKeyValue>& kv = getRankHostState(rank);
kv->set(hostBytesBuffer);
kv->pushFull();
}

std::string MpiWorld::getHostForRank(int rank)
{
// Pull from state if not present
if (rankHostMap.find(rank) == rankHostMap.end()) {
faabric::util::FullLock lock(worldMutex);

if (rankHostMap.find(rank) == rankHostMap.end()) {
auto buffer = new uint8_t[MPI_HOST_STATE_LEN];
const std::shared_ptr<state::StateKeyValue>& kv =
getRankHostState(rank);
kv->get(buffer);

char* bufferChar = reinterpret_cast<char*>(buffer);
if (bufferChar[0] == '\0') {
// No entry for other rank
throw std::runtime_error(
fmt::format("No host entry for rank {}", rank));
}

// Note - we rely on C strings detecting the null terminator here,
// assuming the host will either be an IP or string of alphanumeric
// characters and dots
std::string otherHost(bufferChar);
rankHostMap[rank] = otherHost;
}
logger->error("No known host for rank {}", rank);
throw std::runtime_error("No known host for rank");
}

{
Expand All @@ -272,6 +208,24 @@ std::string MpiWorld::getHostForRank(int rank)
}
}

// Prepare the host-rank map from a vector containing the hosts for _all_ ranks
void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg)
{
assert(msg.hosts().size() == size);
faabric::util::FullLock lock(worldMutex);
for (int i = 0; i < size; i++) {
auto it = rankHostMap.try_emplace(i, msg.hosts().at(i));
if (!it.second) {
logger->error("Tried to map host ({}) to rank ({}), but rank was "
"already mapped to host ({})",
msg.hosts().at(i),
i,
rankHostMap[i]);
throw std::runtime_error("Rank already mapped to host");
}
}
}

void MpiWorld::getCartesianRank(int rank,
int maxDims,
const int* dims,
6 changes: 3 additions & 3 deletions src/scheduler/MpiWorldRegistry.cpp
@@ -38,15 +38,15 @@ scheduler::MpiWorld& MpiWorldRegistry::createWorld(const faabric::Message& msg,
return worldMap[worldId];
}

MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg,
int worldId)
MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg)
{
// Create the world locally if it doesn't exist
int worldId = msg.mpiworldid();
if (worldMap.find(worldId) == worldMap.end()) {
faabric::util::FullLock lock(registryMutex);
if (worldMap.find(worldId) == worldMap.end()) {
MpiWorld& world = worldMap[worldId];
world.initialiseFromState(msg, worldId);
world.initialiseFromMsg(msg);
}
}

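Taken together, the receiving side is now driven entirely by the message: the first rank to reach a host initialises the world, blocking on the rank-host broadcast, and later ranks on the same host hit the cached entry via the double-checked lock above. A rough sketch, assuming msg is the faabric::Message dispatched to a non-zero rank:

#include <faabric/scheduler/MpiWorldRegistry.h>

faabric::scheduler::MpiWorldRegistry& reg =
  faabric::scheduler::getMpiWorldRegistry();

// First caller on this host: initialiseFromMsg() blocks in
// recvMpiHostRankMsg() until rank 0 broadcasts the mapping
faabric::scheduler::MpiWorld& world = reg.getOrInitialiseWorld(msg);

// Lookups are purely in-memory; an unmapped rank now throws instead of
// falling back to a state pull
std::string host = world.getHostForRank(msg.mpirank());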
