State-less MPI rank management #103

Merged · 7 commits · Jun 7, 2021
2 changes: 1 addition & 1 deletion include/faabric/scheduler/MpiContext.h
@@ -9,7 +9,7 @@ class MpiContext
public:
MpiContext();

void createWorld(const faabric::Message& msg);
int createWorld(const faabric::Message& msg);

void joinWorld(const faabric::Message& msg);

20 changes: 4 additions & 16 deletions include/faabric/scheduler/MpiWorld.h
@@ -15,28 +15,22 @@ namespace faabric::scheduler {
typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
InMemoryMpiQueue;

struct MpiWorldState
{
int worldSize;
};

std::string getWorldStateKey(int worldId);

std::string getRankStateKey(int worldId, int rankId);

class MpiWorld
{
public:
MpiWorld();

void create(const faabric::Message& call, int newId, int newSize);

void initialiseFromState(const faabric::Message& msg, int worldId);

void registerRank(int rank);
void initialiseFromMsg(const faabric::Message& msg,
bool forceLocal = false);

std::string getHostForRank(int rank);

void setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg);

std::string getUser();

std::string getFunction();
@@ -238,17 +232,11 @@ class MpiWorld

std::vector<int> cartProcsPerDim;

void setUpStateKV();

std::shared_ptr<state::StateKeyValue> getRankHostState(int rank);

faabric::scheduler::FunctionCallClient& getFunctionCallClient(
const std::string& otherHost);

void checkRankOnThisHost(int rank);

void pushToState();

void closeThreadLocalClients();
};
}
3 changes: 1 addition & 2 deletions include/faabric/scheduler/MpiWorldRegistry.h
@@ -12,8 +12,7 @@ class MpiWorldRegistry
int worldId,
std::string hostOverride = "");

scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg,
int worldId);
scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg);

scheduler::MpiWorld& getWorld(int worldId);

13 changes: 13 additions & 0 deletions include/faabric/transport/MpiMessageEndpoint.h
@@ -0,0 +1,13 @@
#pragma once
Collaborator Author:

Currently, the implementation of these is so simple that they are alright as plain free functions. However, I thought they belonged in transport rather than scheduler.


#include <faabric/proto/faabric.pb.h>
#include <faabric/transport/MessageEndpoint.h>
#include <faabric/transport/common.h>
#include <faabric/transport/macros.h>

namespace faabric::transport {
faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();

void sendMpiHostRankMsg(const std::string& hostIn,
const faabric::MpiHostsToRanksMessage msg);
}
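
For context on the comment above, here is a minimal sketch of what the two helpers might look like. An in-memory queue stands in for the real faabric transport endpoint (which this PR does not show), so only the protobuf serialise/parse calls are taken as given.

#include <queue>
#include <string>

#include <faabric/proto/faabric.pb.h>

namespace sketch {
// Stand-in "wire": in reality this would be a socket bound to MPI_PORT
static std::queue<std::string> wire;

void sendMpiHostRankMsg(const std::string& hostIn,
                        const faabric::MpiHostsToRanksMessage msg)
{
    // Serialise the mapping and ship it to hostIn (here: push onto the queue)
    std::string buffer;
    msg.SerializeToString(&buffer);
    wire.push(buffer);
}

faabric::MpiHostsToRanksMessage recvMpiHostRankMsg()
{
    // Block until the master's broadcast arrives (here: pop from the queue)
    std::string buffer = wire.front();
    wire.pop();

    faabric::MpiHostsToRanksMessage msg;
    msg.ParseFromString(buffer);
    return msg;
}
}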
2 changes: 2 additions & 0 deletions include/faabric/transport/common.h
@@ -8,3 +8,5 @@
#define MPI_MESSAGE_PORT 8005
#define SNAPSHOT_PORT 8006
#define REPLY_PORT_OFFSET 100

#define MPI_PORT 8800
Collaborator @Shillaker (Jun 7, 2021):

How is the MPI_PORT different to the existing MPI_MESSAGE_PORT and how come it's so much higher in the port range than the others? It would be cleanest to keep the range of ports we use as narrow as possible (e.g. could this just be 8007?)

Collaborator Author:

After our offline discussion, I'll leave the port as it is, since it will change in upcoming PRs.

3 changes: 1 addition & 2 deletions src/mpi_native/mpi_native.cpp
@@ -28,10 +28,9 @@ faabric::Message* getExecutingCall()

faabric::scheduler::MpiWorld& getExecutingWorld()
{
int worldId = executingContext.getWorldId();
faabric::scheduler::MpiWorldRegistry& reg =
faabric::scheduler::getMpiWorldRegistry();
return reg.getOrInitialiseWorld(*getExecutingCall(), worldId);
return reg.getOrInitialiseWorld(*getExecutingCall());
}

static void notImplemented(const std::string& funcName)
7 changes: 7 additions & 0 deletions src/proto/faabric.proto
@@ -84,6 +84,13 @@ message MPIMessage {
bytes buffer = 8;
}

// Instead of sending a map or a list of ranks, we use the index into the
// repeated string field as the rank. Note that protobuf guarantees that the
// order of repeated fields is preserved.
message MpiHostsToRanksMessage {
repeated string hosts = 1;
}

message Message {
int32 id = 1;
int32 appId = 2;
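As a quick illustration of the index-as-rank convention (the host addresses are made up), using the standard protobuf-generated accessors for the repeated hosts field:

#include <cassert>
#include <string>

#include <faabric/proto/faabric.pb.h>

int main()
{
    // The i-th entry in the repeated field is the host running rank i
    faabric::MpiHostsToRanksMessage hostRankMsg;
    hostRankMsg.add_hosts("10.0.1.1"); // rank 0 (master's host)
    hostRankMsg.add_hosts("10.0.1.2"); // rank 1
    hostRankMsg.add_hosts("10.0.1.2"); // rank 2

    // A receiver recovers the host for a given rank by indexing the field
    assert(hostRankMsg.hosts_size() == 3);
    std::string hostForRankTwo = hostRankMsg.hosts(2);
    assert(hostForRankTwo == "10.0.1.2");

    return 0;
}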
8 changes: 5 additions & 3 deletions src/scheduler/MpiContext.cpp
@@ -12,7 +12,7 @@ MpiContext::MpiContext()
, worldId(-1)
{}

void MpiContext::createWorld(const faabric::Message& msg)
int MpiContext::createWorld(const faabric::Message& msg)
{
const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();

@@ -32,6 +32,9 @@ void MpiContext::createWorld(const faabric::Message& msg)
// Set up this context
isMpi = true;
rank = 0;

// Return the world id to store it in the original message
return worldId;
}

void MpiContext::joinWorld(const faabric::Message& msg)
@@ -47,8 +50,7 @@ void MpiContext::joinWorld(const faabric::Message& msg)

// Register with the world
MpiWorldRegistry& registry = getMpiWorldRegistry();
MpiWorld& world = registry.getOrInitialiseWorld(msg, worldId);
world.registerRank(rank);
registry.getOrInitialiseWorld(msg);
}

bool MpiContext::getIsMpi()
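A hypothetical call site for the new return value; the createWorld() change itself is in this PR, but the surrounding function and the use of messageFactory here are illustrative assumptions, not code from the diff.

#include <faabric/scheduler/MpiContext.h>
#include <faabric/util/func.h>

// Sketch only: createWorld() now returns the world id so the caller can
// stamp it back onto the original message before it is dispatched further
void spawnMpiWorld()
{
    faabric::Message msg = faabric::util::messageFactory("mpi", "hellompi");
    msg.set_mpiworldsize(4);

    faabric::scheduler::MpiContext ctx;
    int worldId = ctx.createWorld(msg);
    msg.set_mpiworldid(worldId);
}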
152 changes: 53 additions & 99 deletions src/scheduler/MpiWorld.cpp
@@ -4,6 +4,7 @@
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/state/State.h>
#include <faabric/transport/MpiMessageEndpoint.h>
#include <faabric/util/environment.h>
#include <faabric/util/func.h>
#include <faabric/util/gids.h>
@@ -35,39 +36,12 @@ std::string getWorldStateKey(int worldId)
return "mpi_world_" + std::to_string(worldId);
}

std::string getRankStateKey(int worldId, int rankId)
{
if (worldId <= 0 || rankId < 0) {
throw std::runtime_error(
fmt::format("World ID must be >0 and rank ID must be >=0 ({}, {})",
worldId,
rankId));
}
return "mpi_rank_" + std::to_string(worldId) + "_" + std::to_string(rankId);
}

std::string getWindowStateKey(int worldId, int rank, size_t size)
{
return "mpi_win_" + std::to_string(worldId) + "_" + std::to_string(rank) +
"_" + std::to_string(size);
}

void MpiWorld::setUpStateKV()
{
if (stateKV == nullptr) {
state::State& state = state::getGlobalState();
std::string stateKey = getWorldStateKey(id);
stateKV = state.getKV(user, stateKey, sizeof(MpiWorldState));
}
}

std::shared_ptr<state::StateKeyValue> MpiWorld::getRankHostState(int rank)
{
state::State& state = state::getGlobalState();
std::string stateKey = getRankStateKey(id, rank);
return state.getKV(user, stateKey, MPI_HOST_STATE_LEN);
}

faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient(
const std::string& otherHost)
{
@@ -114,13 +88,6 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
threadPool = std::make_shared<faabric::scheduler::MpiAsyncThreadPool>(
getMpiThreadPoolSize());

// Write this to state
setUpStateKV();
pushToState();

// Register this as the master
registerRank(0);

auto& sch = faabric::scheduler::getScheduler();

// Dispatch all the chained calls
@@ -133,9 +100,29 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
msg.set_ismpi(true);
msg.set_mpiworldid(id);
msg.set_mpirank(i + 1);
msg.set_mpiworldsize(size);
}

sch.callFunctions(req);
// Send the init messages (note that message i corresponds to rank i+1)
std::vector<std::string> executedAt = sch.callFunctions(req);
assert(executedAt.size() == size - 1);

// Prepend this host for rank 0
executedAt.insert(executedAt.begin(), thisHost);

// Register hosts to rank mappings on this host
faabric::MpiHostsToRanksMessage hostRankMsg;
*hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() };
setAllRankHosts(hostRankMsg);

// Set up a list of hosts to broadcast to (excluding this host)
std::set<std::string> hosts(executedAt.begin(), executedAt.end());
hosts.erase(thisHost);

// Do the broadcast
for (const auto& h : hosts) {
faabric::transport::sendMpiHostRankMsg(h, hostRankMsg);
}
}

void MpiWorld::destroy()
@@ -147,12 +134,6 @@ void MpiWorld::destroy()
// Note - we are deliberately not deleting the KV in the global state
// TODO - find a way to do this only from the master client

for (auto& s : rankHostMap) {
const std::shared_ptr<state::StateKeyValue>& rankState =
getRankHostState(s.first);
state::getGlobalState().deleteKV(rankState->user, rankState->key);
}

// Wait (forever) until all ranks are done consuming their queues to
// clear them.
// Note - this means that an application with outstanding messages, i.e.
@@ -193,77 +174,32 @@ void MpiWorld::closeThreadLocalClients()
functionCallClients.clear();
}

void MpiWorld::initialiseFromState(const faabric::Message& msg, int worldId)
void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
{
id = worldId;
id = msg.mpiworldid();
user = msg.user();
function = msg.function();
size = msg.mpiworldsize();

setUpStateKV();

// Read from state
MpiWorldState s{};
stateKV->pull();
stateKV->get(BYTES(&s));
size = s.worldSize;
threadPool = std::make_shared<faabric::scheduler::MpiAsyncThreadPool>(
getMpiThreadPoolSize());
}

void MpiWorld::pushToState()
{
// Write to state
MpiWorldState s{
.worldSize = this->size,
};

stateKV->set(BYTES(&s));
stateKV->pushFull();
}

void MpiWorld::registerRank(int rank)
{
{
faabric::util::FullLock lock(worldMutex);
rankHostMap[rank] = thisHost;
// Sometimes, for testing purposes, we may want to initialise a world on the
// _same_ host where we created it (note that this would never happen in
// reality). If so, we skip the rank-host map broadcast.
if (!forceLocal) {
// Block until we receive
faabric::MpiHostsToRanksMessage hostRankMsg =
faabric::transport::recvMpiHostRankMsg();
setAllRankHosts(hostRankMsg);
}

// Note that the host name may be shorter than the buffer, so we need to pad
// with nulls
uint8_t hostBytesBuffer[MPI_HOST_STATE_LEN];
memset(hostBytesBuffer, '\0', MPI_HOST_STATE_LEN);
::strcpy((char*)hostBytesBuffer, thisHost.c_str());

const std::shared_ptr<state::StateKeyValue>& kv = getRankHostState(rank);
kv->set(hostBytesBuffer);
kv->pushFull();
}

std::string MpiWorld::getHostForRank(int rank)
{
// Pull from state if not present
if (rankHostMap.find(rank) == rankHostMap.end()) {
faabric::util::FullLock lock(worldMutex);

if (rankHostMap.find(rank) == rankHostMap.end()) {
auto buffer = new uint8_t[MPI_HOST_STATE_LEN];
const std::shared_ptr<state::StateKeyValue>& kv =
getRankHostState(rank);
kv->get(buffer);

char* bufferChar = reinterpret_cast<char*>(buffer);
if (bufferChar[0] == '\0') {
// No entry for other rank
throw std::runtime_error(
fmt::format("No host entry for rank {}", rank));
}

// Note - we rely on C strings detecting the null terminator here,
// assuming the host will either be an IP or string of alphanumeric
// characters and dots
std::string otherHost(bufferChar);
rankHostMap[rank] = otherHost;
}
logger->error("No known host for rank {}", rank);
throw std::runtime_error("No known host for rank");
}

{
@@ -272,6 +208,24 @@ std::string MpiWorld::getHostForRank(int rank)
}
}

// Prepare the host-rank map with a vector containing _all_ ranks
void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg)
{
assert(msg.hosts().size() == size);
faabric::util::FullLock lock(worldMutex);
for (int i = 0; i < size; i++) {
auto it = rankHostMap.try_emplace(i, msg.hosts().at(i));
if (!it.second) {
logger->error("Tried to map host ({}) to rank ({}), but rank was "
"already mapped to host ({})",
msg.hosts().at(i),
i + 1,
rankHostMap[i]);
throw std::runtime_error("Rank already mapped to host");
}
}
}

void MpiWorld::getCartesianRank(int rank,
int maxDims,
const int* dims,
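To make the new, state-less flow in MpiWorld concrete, here is a small self-contained sketch of the rank-to-host bookkeeping. Plain standard-library containers stand in for the MpiWorld members, and the host addresses are made up.

#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

int main()
{
    // What the master assembles in create(): rank 0 runs locally, and the
    // scheduler reports where ranks 1..3 were dispatched
    std::string thisHost = "10.0.1.1";
    std::vector<std::string> executedAt = { "10.0.1.2", "10.0.1.2", "10.0.1.3" };
    executedAt.insert(executedAt.begin(), thisHost);

    // What setAllRankHosts() does on every host once the mapping arrives
    std::map<int, std::string> rankHostMap;
    for (int i = 0; i < (int)executedAt.size(); i++) {
        auto res = rankHostMap.try_emplace(i, executedAt.at(i));
        if (!res.second) {
            throw std::runtime_error("Rank already mapped to host");
        }
    }

    // getHostForRank() is now a purely local lookup, with no global state
    assert(rankHostMap.at(0) == "10.0.1.1");
    assert(rankHostMap.at(2) == "10.0.1.2");
    assert(rankHostMap.at(3) == "10.0.1.3");

    return 0;
}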
6 changes: 3 additions & 3 deletions src/scheduler/MpiWorldRegistry.cpp
@@ -38,15 +38,15 @@ scheduler::MpiWorld& MpiWorldRegistry::createWorld(const faabric::Message& msg,
return worldMap[worldId];
}

MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg,
int worldId)
MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg)
{
// Create world locally if not exists
int worldId = msg.mpiworldid();
if (worldMap.find(worldId) == worldMap.end()) {
faabric::util::FullLock lock(registryMutex);
if (worldMap.find(worldId) == worldMap.end()) {
MpiWorld& world = worldMap[worldId];
world.initialiseFromState(msg, worldId);
world.initialiseFromMsg(msg);
}
}
