diff --git a/include/faabric/scheduler/MpiContext.h b/include/faabric/scheduler/MpiContext.h index 9a01f9af4..d898793f1 100644 --- a/include/faabric/scheduler/MpiContext.h +++ b/include/faabric/scheduler/MpiContext.h @@ -9,7 +9,7 @@ class MpiContext public: MpiContext(); - void createWorld(const faabric::Message& msg); + int createWorld(const faabric::Message& msg); void joinWorld(const faabric::Message& msg); diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index 68a78a3b2..39a788f4e 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -15,15 +15,8 @@ namespace faabric::scheduler { typedef faabric::util::Queue> InMemoryMpiQueue; -struct MpiWorldState -{ - int worldSize; -}; - std::string getWorldStateKey(int worldId); -std::string getRankStateKey(int worldId, int rankId); - class MpiWorld { public: @@ -31,12 +24,13 @@ class MpiWorld void create(const faabric::Message& call, int newId, int newSize); - void initialiseFromState(const faabric::Message& msg, int worldId); - - void registerRank(int rank); + void initialiseFromMsg(const faabric::Message& msg, + bool forceLocal = false); std::string getHostForRank(int rank); + void setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg); + std::string getUser(); std::string getFunction(); @@ -238,17 +232,11 @@ class MpiWorld std::vector cartProcsPerDim; - void setUpStateKV(); - - std::shared_ptr getRankHostState(int rank); - faabric::scheduler::FunctionCallClient& getFunctionCallClient( const std::string& otherHost); void checkRankOnThisHost(int rank); - void pushToState(); - void closeThreadLocalClients(); }; } diff --git a/include/faabric/scheduler/MpiWorldRegistry.h b/include/faabric/scheduler/MpiWorldRegistry.h index ea3aa9ff7..303c099df 100644 --- a/include/faabric/scheduler/MpiWorldRegistry.h +++ b/include/faabric/scheduler/MpiWorldRegistry.h @@ -12,8 +12,7 @@ class MpiWorldRegistry int worldId, std::string hostOverride = ""); - scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg, - int worldId); + scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg); scheduler::MpiWorld& getWorld(int worldId); diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h new file mode 100644 index 000000000..252a75bcc --- /dev/null +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include +#include + +namespace faabric::transport { +faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); + +void sendMpiHostRankMsg(const std::string& hostIn, + const faabric::MpiHostsToRanksMessage msg); +} diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index 1268fc48f..ad8b84f40 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -8,3 +8,5 @@ #define MPI_MESSAGE_PORT 8005 #define SNAPSHOT_PORT 8006 #define REPLY_PORT_OFFSET 100 + +#define MPI_PORT 8800 diff --git a/src/mpi_native/mpi_native.cpp b/src/mpi_native/mpi_native.cpp index 1f06d616e..d9d0faa19 100644 --- a/src/mpi_native/mpi_native.cpp +++ b/src/mpi_native/mpi_native.cpp @@ -28,10 +28,9 @@ faabric::Message* getExecutingCall() faabric::scheduler::MpiWorld& getExecutingWorld() { - int worldId = executingContext.getWorldId(); faabric::scheduler::MpiWorldRegistry& reg = faabric::scheduler::getMpiWorldRegistry(); - return reg.getOrInitialiseWorld(*getExecutingCall(), worldId); + return 
reg.getOrInitialiseWorld(*getExecutingCall()); } static void notImplemented(const std::string& funcName) diff --git a/src/proto/faabric.proto index efd31d201..55b14b614 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -84,6 +84,13 @@ message MPIMessage { bytes buffer = 8; } +// Instead of sending a map, or a list of ranks, we use the index in the +// repeated string field as the rank. Note that protobuf preserves the order +// of repeated fields. +message MpiHostsToRanksMessage { + repeated string hosts = 1; +} + message Message { int32 id = 1; int32 appId = 2;
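As an illustration of the convention introduced here (not part of the patch): the new MpiHostsToRanksMessage encodes the rank purely by position in the repeated field. Assuming the usual protobuf-generated API for a repeated string field, and the generated header path shown in the include, the mapping can be built and read back like this:

```cpp
#include <map>
#include <string>
#include <vector>

#include <faabric/proto/faabric.pb.h> // assumed path of the generated header

// Build the message from a per-rank host list: entry i is the host of rank i
faabric::MpiHostsToRanksMessage buildHostRankMsg(
  const std::vector<std::string>& hostForRank)
{
    faabric::MpiHostsToRanksMessage msg;
    for (const auto& host : hostForRank) {
        msg.add_hosts(host);
    }
    return msg;
}

// Read it back: the index in the repeated field is the rank
std::map<int, std::string> readHostRankMsg(
  const faabric::MpiHostsToRanksMessage& msg)
{
    std::map<int, std::string> rankHosts;
    for (int rank = 0; rank < msg.hosts_size(); rank++) {
        rankHosts[rank] = msg.hosts(rank);
    }
    return rankHosts;
}
```

Because protobuf preserves the order of repeated fields, no explicit rank field is needed, which keeps the message smaller and simpler on the wire than a map keyed by rank.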
diff --git a/src/scheduler/MpiContext.cpp index 130bcc60a..c9ce8e77b 100644 --- a/src/scheduler/MpiContext.cpp +++ b/src/scheduler/MpiContext.cpp @@ -12,7 +12,7 @@ MpiContext::MpiContext() , worldId(-1) {} -void MpiContext::createWorld(const faabric::Message& msg) +int MpiContext::createWorld(const faabric::Message& msg) { const std::shared_ptr& logger = faabric::util::getLogger(); @@ -32,6 +32,9 @@ void MpiContext::createWorld(const faabric::Message& msg) // Set up this context isMpi = true; rank = 0; + + // Return the world id to store it in the original message + return worldId; } void MpiContext::joinWorld(const faabric::Message& msg) @@ -47,8 +50,7 @@ void MpiContext::joinWorld(const faabric::Message& msg) // Register with the world MpiWorldRegistry& registry = getMpiWorldRegistry(); - MpiWorld& world = registry.getOrInitialiseWorld(msg, worldId); - world.registerRank(rank); + registry.getOrInitialiseWorld(msg); } bool MpiContext::getIsMpi() diff --git a/src/scheduler/MpiWorld.cpp index 3f0c295bd..3054490c8 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -35,39 +36,12 @@ std::string getWorldStateKey(int worldId) return "mpi_world_" + std::to_string(worldId); } -std::string getRankStateKey(int worldId, int rankId) -{ - if (worldId <= 0 || rankId < 0) { - throw std::runtime_error( - fmt::format("World ID must be >0 and rank ID must be >=0 ({}, {})", - worldId, - rankId)); - } - return "mpi_rank_" + std::to_string(worldId) + "_" + std::to_string(rankId); -} - std::string getWindowStateKey(int worldId, int rank, size_t size) { return "mpi_win_" + std::to_string(worldId) + "_" + std::to_string(rank) + "_" + std::to_string(size); } -void MpiWorld::setUpStateKV() -{ - if (stateKV == nullptr) { - state::State& state = state::getGlobalState(); - std::string stateKey = getWorldStateKey(id); - stateKV = state.getKV(user, stateKey, sizeof(MpiWorldState)); - } -} - -std::shared_ptr MpiWorld::getRankHostState(int rank) -{ - state::State& state = state::getGlobalState(); - std::string stateKey = getRankStateKey(id, rank); - return state.getKV(user, stateKey, MPI_HOST_STATE_LEN); -} - faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient( const std::string& otherHost) { @@ -114,13 +88,6 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) threadPool = std::make_shared( getMpiThreadPoolSize()); - // Write this to state - setUpStateKV(); - pushToState(); - - // Register this as the master - registerRank(0); - auto& sch = faabric::scheduler::getScheduler(); // Dispatch all the chained calls @@ -133,9 +100,29 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) msg.set_ismpi(true); msg.set_mpiworldid(id); msg.set_mpirank(i + 1); + msg.set_mpiworldsize(size); } - sch.callFunctions(req); + // Send the init messages (note that message i corresponds to rank i+1) + std::vector<std::string> executedAt = sch.callFunctions(req); + assert(executedAt.size() == size - 1); + + // Prepend this host for rank 0 + executedAt.insert(executedAt.begin(), thisHost); + + // Register the host-to-rank mappings on this host + faabric::MpiHostsToRanksMessage hostRankMsg; + *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() }; + setAllRankHosts(hostRankMsg); + + // Set up a list of hosts to broadcast to (excluding this host) + std::set<std::string> hosts(executedAt.begin(), executedAt.end()); + hosts.erase(thisHost); + + // Do the broadcast + for (const auto& h : hosts) { + faabric::transport::sendMpiHostRankMsg(h, hostRankMsg); + } } void MpiWorld::destroy() @@ -147,12 +134,6 @@ void MpiWorld::destroy() // Note - we are deliberately not deleting the KV in the global state // TODO - find a way to do this only from the master client - for (auto& s : rankHostMap) { - const std::shared_ptr& rankState = - getRankHostState(s.first); - state::getGlobalState().deleteKV(rankState->user, rankState->key); - } - // Wait (forever) until all ranks are done consuming their queues to // clear them. // Note - this means that an application with outstanding messages, i.e. @@ -193,77 +174,32 @@ void MpiWorld::closeThreadLocalClients() functionCallClients.clear(); } -void MpiWorld::initialiseFromState(const faabric::Message& msg, int worldId) +void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) { - id = worldId; + id = msg.mpiworldid(); user = msg.user(); function = msg.function(); + size = msg.mpiworldsize(); - setUpStateKV(); - - // Read from state - MpiWorldState s{}; - stateKV->pull(); - stateKV->get(BYTES(&s)); - size = s.worldSize; threadPool = std::make_shared( getMpiThreadPoolSize()); -} - -void MpiWorld::pushToState() -{ - // Write to state - MpiWorldState s{ - .worldSize = this->size, - }; - - stateKV->set(BYTES(&s)); - stateKV->pushFull(); -} -void MpiWorld::registerRank(int rank) -{ - { - faabric::util::FullLock lock(worldMutex); - rankHostMap[rank] = thisHost; + // Sometimes for testing purposes we may want to initialise a world in the + // _same_ host where we have created one (note that this would never happen + // in reality). If so, we skip the rank-host map broadcasting. 
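Before the guarded block that follows, it is worth summarising the two sides of the exchange this patch introduces. The sketch below is illustrative only, not part of the patch: broadcastRankHosts and receiveRankHosts are invented names, while sendMpiHostRankMsg, recvMpiHostRankMsg and setAllRankHosts are the functions added in this diff, and executedAt/thisHost stand in for the values computed in MpiWorld::create() above.

```cpp
#include <set>
#include <string>
#include <vector>

#include <faabric/scheduler/MpiWorld.h>
#include <faabric/transport/MpiMessageEndpoint.h>

// Master side (rank 0, inside MpiWorld::create()): executedAt holds the host
// chosen for each chained call, i.e. entry i is the host of rank i + 1, so
// this host must be prepended for rank 0 before the mapping is broadcast.
void broadcastRankHosts(std::vector<std::string> executedAt,
                        const std::string& thisHost)
{
    executedAt.insert(executedAt.begin(), thisHost);

    faabric::MpiHostsToRanksMessage hostRankMsg;
    *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() };

    // Every remote host involved in the world receives the full mapping once
    std::set<std::string> otherHosts(executedAt.begin(), executedAt.end());
    otherHosts.erase(thisHost);
    for (const auto& host : otherHosts) {
        faabric::transport::sendMpiHostRankMsg(host, hostRankMsg);
    }
}

// Worker side (inside initialiseFromMsg(), unless forceLocal is set): block
// until the master pushes the mapping, then record it locally.
void receiveRankHosts(faabric::scheduler::MpiWorld& world)
{
    faabric::MpiHostsToRanksMessage hostRankMsg =
      faabric::transport::recvMpiHostRankMsg();
    world.setAllRankHosts(hostRankMsg);
}
```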
+ if (!forceLocal) { + // Block until we receive + faabric::MpiHostsToRanksMessage hostRankMsg = + faabric::transport::recvMpiHostRankMsg(); + setAllRankHosts(hostRankMsg); } - - // Note that the host name may be shorter than the buffer, so we need to pad - // with nulls - uint8_t hostBytesBuffer[MPI_HOST_STATE_LEN]; - memset(hostBytesBuffer, '\0', MPI_HOST_STATE_LEN); - ::strcpy((char*)hostBytesBuffer, thisHost.c_str()); - - const std::shared_ptr& kv = getRankHostState(rank); - kv->set(hostBytesBuffer); - kv->pushFull(); } std::string MpiWorld::getHostForRank(int rank) { - // Pull from state if not present if (rankHostMap.find(rank) == rankHostMap.end()) { - faabric::util::FullLock lock(worldMutex); - - if (rankHostMap.find(rank) == rankHostMap.end()) { - auto buffer = new uint8_t[MPI_HOST_STATE_LEN]; - const std::shared_ptr& kv = - getRankHostState(rank); - kv->get(buffer); - - char* bufferChar = reinterpret_cast(buffer); - if (bufferChar[0] == '\0') { - // No entry for other rank - throw std::runtime_error( - fmt::format("No host entry for rank {}", rank)); - } - - // Note - we rely on C strings detecting the null terminator here, - // assuming the host will either be an IP or string of alphanumeric - // characters and dots - std::string otherHost(bufferChar); - rankHostMap[rank] = otherHost; - } + logger->error("No known host for rank {}", rank); + throw std::runtime_error("No known host for rank"); } { @@ -272,6 +208,24 @@ std::string MpiWorld::getHostForRank(int rank) } } +// Prepare the host-rank map with a vector containing _all_ ranks +void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg) +{ + assert(msg.hosts().size() == size); + faabric::util::FullLock lock(worldMutex); + for (int i = 0; i < size; i++) { + auto it = rankHostMap.try_emplace(i, msg.hosts().at(i)); + if (!it.second) { + logger->error("Tried to map host ({}) to rank ({}), but rank was " + "already mapped to host ({})", + msg.hosts().at(i), + i + 1, + rankHostMap[i]); + throw std::runtime_error("Rank already mapped to host"); + } + } +} + void MpiWorld::getCartesianRank(int rank, int maxDims, const int* dims, diff --git a/src/scheduler/MpiWorldRegistry.cpp b/src/scheduler/MpiWorldRegistry.cpp index e04e2b6fc..6e07431e2 100644 --- a/src/scheduler/MpiWorldRegistry.cpp +++ b/src/scheduler/MpiWorldRegistry.cpp @@ -38,15 +38,15 @@ scheduler::MpiWorld& MpiWorldRegistry::createWorld(const faabric::Message& msg, return worldMap[worldId]; } -MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg, - int worldId) +MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg) { // Create world locally if not exists + int worldId = msg.mpiworldid(); if (worldMap.find(worldId) == worldMap.end()) { faabric::util::FullLock lock(registryMutex); if (worldMap.find(worldId) == worldMap.end()) { MpiWorld& world = worldMap[worldId]; - world.initialiseFromState(msg, worldId); + world.initialiseFromMsg(msg); } } diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index 15cf39fc5..3ac400f19 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -10,6 +10,7 @@ set(HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpoint.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointClient.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointServer.h" + "${FAABRIC_INCLUDE_DIR}/faabric/transport/MpiMessageEndpoint.h" ) set(LIB_FILES @@ -18,6 +19,7 @@ set(LIB_FILES MessageEndpoint.cpp MessageEndpointClient.cpp 
MessageEndpointServer.cpp + MpiMessageEndpoint.cpp ${HEADERS} ) diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp new file mode 100644 index 000000000..0076ccc65 --- /dev/null +++ b/src/transport/MpiMessageEndpoint.cpp @@ -0,0 +1,30 @@ +#include + +namespace faabric::transport { +faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() +{ + faabric::transport::RecvMessageEndpoint endpoint(MPI_PORT); + endpoint.open(faabric::transport::getGlobalMessageContext()); + faabric::transport::Message m = endpoint.recv(); + PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); + endpoint.close(); + + return msg; +} + +void sendMpiHostRankMsg(const std::string& hostIn, + const faabric::MpiHostsToRanksMessage msg) +{ + size_t msgSize = msg.ByteSizeLong(); + { + uint8_t sMsg[msgSize]; + if (!msg.SerializeToArray(sMsg, msgSize)) { + throw std::runtime_error("Error serialising message"); + } + faabric::transport::SendMessageEndpoint endpoint(hostIn, MPI_PORT); + endpoint.open(faabric::transport::getGlobalMessageContext()); + endpoint.send(sMsg, msgSize, false); + endpoint.close(); + } +} +} diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index 0041f2d5f..f90684f77 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -53,9 +53,27 @@ class ClientServerFixture TEST_CASE_METHOD(ClientServerFixture, "Test sending MPI message", "[scheduler]") { - // Create an MPI world on this host and one on a "remote" host - std::string otherHost = "192.168.9.2"; + auto& sch = faabric::scheduler::getScheduler(); + + // Force the scheduler to initialise a world in the remote host by setting + // a world size bigger than the slots available locally + int worldSize = 2; + faabric::HostResources localResources; + localResources.set_slots(1); + localResources.set_usedslots(1); + faabric::HostResources otherResources; + otherResources.set_slots(1); + + // Set up a remote host + std::string otherHost = LOCALHOST; + sch.addHostToGlobalSet(otherHost); + + // Mock everything to make sure the other host has resources as well + faabric::util::setMockMode(true); + sch.setThisHostResources(localResources); + faabric::scheduler::queueResourceResponse(otherHost, otherResources); + // Create an MPI world on this host and one on a "remote" host const char* user = "mpi"; const char* func = "hellompi"; int worldId = 123; @@ -63,22 +81,22 @@ TEST_CASE_METHOD(ClientServerFixture, "Test sending MPI message", "[scheduler]") msg.set_user(user); msg.set_function(func); msg.set_mpiworldid(worldId); - msg.set_mpiworldsize(2); + msg.set_mpiworldsize(worldSize); faabric::util::messageFactory(user, func); scheduler::MpiWorldRegistry& registry = getMpiWorldRegistry(); - scheduler::MpiWorld& localWorld = - registry.createWorld(msg, worldId, LOCALHOST); + scheduler::MpiWorld& localWorld = registry.createWorld(msg, worldId); scheduler::MpiWorld remoteWorld; remoteWorld.overrideHost(otherHost); - remoteWorld.initialiseFromState(msg, worldId); + remoteWorld.initialiseFromMsg(msg); // Register a rank on each int rankLocal = 0; int rankRemote = 1; - localWorld.registerRank(rankLocal); - remoteWorld.registerRank(rankRemote); + + // Undo the mocking, so we actually send the MPI message + faabric::util::setMockMode(false); // Create a message faabric::MPIMessage mpiMsg; diff --git a/tests/test/scheduler/test_mpi_context.cpp b/tests/test/scheduler/test_mpi_context.cpp index 
addf1a045..d6a547257 100644 --- a/tests/test/scheduler/test_mpi_context.cpp +++ b/tests/test/scheduler/test_mpi_context.cpp @@ -29,6 +29,7 @@ TEST_CASE("Check world creation", "[mpi]") // Check a new world ID is created int worldId = c.getWorldId(); REQUIRE(worldId > 0); + msg.set_mpiworldid(worldId); // Check this context is set up REQUIRE(c.getIsMpi()); @@ -36,13 +37,12 @@ TEST_CASE("Check world creation", "[mpi]") // Get the world and check it is set up MpiWorldRegistry& reg = getMpiWorldRegistry(); - MpiWorld& world = reg.getOrInitialiseWorld(msg, worldId); + MpiWorld& world = reg.getOrInitialiseWorld(msg); REQUIRE(world.getId() == worldId); REQUIRE(world.getSize() == 10); REQUIRE(world.getUser() == "mpi"); REQUIRE(world.getFunction() == "hellompi"); - world.destroy(); tearDown(world); } @@ -84,10 +84,11 @@ TEST_CASE("Check default world size is set", "[mpi]") msg.set_mpiworldsize(requestedWorldSize); c.createWorld(msg); int worldId = c.getWorldId(); + msg.set_mpiworldid(worldId); // Check that the size is set to the default MpiWorldRegistry& reg = getMpiWorldRegistry(); - MpiWorld& world = reg.getOrInitialiseWorld(msg, worldId); + MpiWorld& world = reg.getOrInitialiseWorld(msg); REQUIRE(world.getSize() == defaultWorldSize); // Reset config @@ -131,8 +132,9 @@ TEST_CASE("Check joining world", "[mpi]") // Check rank is registered to this host MpiWorldRegistry& reg = getMpiWorldRegistry(); - MpiWorld& world = reg.getOrInitialiseWorld(msgB, worldId); + MpiWorld& world = reg.getOrInitialiseWorld(msgB); const std::string actualHost = world.getHostForRank(1); + REQUIRE(actualHost == expectedHost); tearDown(world); } diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp index 22b15572d..4bf4b3fa8 100644 --- a/tests/test/scheduler/test_mpi_world.cpp +++ b/tests/test/scheduler/test_mpi_world.cpp @@ -55,6 +55,7 @@ TEST_CASE("Test world creation", "[mpi]") REQUIRE(actualCall.ismpi()); REQUIRE(actualCall.mpiworldid() == worldId); REQUIRE(actualCall.mpirank() == i + 1); + REQUIRE(actualCall.mpiworldsize() == worldSize); } // Check that this host is registered as the master @@ -64,18 +65,23 @@ TEST_CASE("Test world creation", "[mpi]") tearDown({ &world }); } -TEST_CASE("Test world loading from state", "[mpi]") +TEST_CASE("Test world loading from msg", "[mpi]") { cleanFaabric(); // Create a world - const faabric::Message& msg = faabric::util::messageFactory(user, func); + faabric::Message msg = faabric::util::messageFactory(user, func); scheduler::MpiWorld worldA; worldA.create(msg, worldId, worldSize); // Create another copy from state scheduler::MpiWorld worldB; - worldB.initialiseFromState(msg, worldId); + // These would be set by the master rank, when invoking other ranks + msg.set_mpiworldsize(worldSize); + msg.set_mpiworldid(worldId); + // Force creating the second world in the _same_ host + bool forceLocal = true; + worldB.initialiseFromMsg(msg, forceLocal); REQUIRE(worldB.getSize() == worldSize); REQUIRE(worldB.getId() == worldId); @@ -85,42 +91,49 @@ TEST_CASE("Test world loading from state", "[mpi]") tearDown({ &worldA, &worldB }); } -TEST_CASE("Test registering a rank", "[mpi]") +TEST_CASE("Test rank allocation", "[mpi]") { cleanFaabric(); - // Note, we deliberately make the host names different lengths, - // shorter than the buffer - std::string hostA = faabric::util::randomString(MPI_HOST_STATE_LEN - 5); - std::string hostB = faabric::util::randomString(MPI_HOST_STATE_LEN - 10); + auto& sch = faabric::scheduler::getScheduler(); - // Create a world 
- const faabric::Message& msg = faabric::util::messageFactory(user, func); - scheduler::MpiWorld worldA; - worldA.overrideHost(hostA); - worldA.create(msg, worldId, worldSize); + // Force the scheduler to initialise a world in the remote host by setting + // a worldSize bigger than the slots available locally + int worldSize = 2; + faabric::HostResources localResources; + localResources.set_slots(1); + localResources.set_usedslots(1); + faabric::HostResources otherResources; + otherResources.set_slots(1); - // Register a rank to this host and check - int rankA = 5; - worldA.registerRank(5); - const std::string actualHost = worldA.getHostForRank(0); - REQUIRE(actualHost == hostA); + std::string thisHost = faabric::util::getSystemConfig().endpointHost; + std::string otherHost = LOCALHOST; + sch.addHostToGlobalSet(otherHost); - // Create a new instance of the world with a new host ID - scheduler::MpiWorld worldB; - worldB.overrideHost(hostB); - worldB.initialiseFromState(msg, worldId); + // Mock everything to make sure the other host has resources as well + faabric::util::setMockMode(true); + sch.setThisHostResources(localResources); + faabric::scheduler::queueResourceResponse(otherHost, otherResources); + + // Create a world + faabric::Message msg = faabric::util::messageFactory(user, func); + msg.set_mpiworldid(worldId); + msg.set_mpiworldsize(worldSize); - int rankB = 4; - worldB.registerRank(4); + // Create the local world + scheduler::MpiWorld& localWorld = + getMpiWorldRegistry().createWorld(msg, worldId); + + scheduler::MpiWorld remoteWorld; + remoteWorld.overrideHost(otherHost); + remoteWorld.initialiseFromMsg(msg); // Now check both world instances report the same mappings - REQUIRE(worldA.getHostForRank(rankA) == hostA); - REQUIRE(worldA.getHostForRank(rankB) == hostB); - REQUIRE(worldB.getHostForRank(rankA) == hostA); - REQUIRE(worldB.getHostForRank(rankB) == hostB); + REQUIRE(localWorld.getHostForRank(0) == thisHost); + REQUIRE(localWorld.getHostForRank(1) == otherHost); - tearDown({ &worldA, &worldB }); + faabric::util::setMockMode(false); + tearDown({ &localWorld, &remoteWorld }); } TEST_CASE("Test cartesian communicator", "[mpi]") @@ -173,8 +186,8 @@ TEST_CASE("Test cartesian communicator", "[mpi]") }; } - MpiWorld& world = - getMpiWorldRegistry().createWorld(msg, worldId, LOCALHOST); + scheduler::MpiWorld world; + world.create(msg, worldId, worldSize); // Get coordinates from rank for (int i = 0; i < worldSize; i++) { @@ -245,17 +258,14 @@ TEST_CASE("Test send and recv on same host", "[mpi]") { cleanFaabric(); - const faabric::Message& msg = faabric::util::messageFactory(user, func); + faabric::Message msg = faabric::util::messageFactory(user, func); + msg.set_mpiworldsize(2); scheduler::MpiWorld world; world.create(msg, worldId, worldSize); - // Register two ranks - int rankA1 = 1; - int rankA2 = 2; - world.registerRank(rankA1); - world.registerRank(rankA2); - // Send a message between colocated ranks + int rankA1 = 0; + int rankA2 = 1; std::vector messageData = { 0, 1, 2 }; world.send( rankA1, rankA2, BYTES(messageData.data()), MPI_INT, messageData.size()); @@ -315,13 +325,9 @@ TEST_CASE("Test sendrecv", "[mpi]") scheduler::MpiWorld world; world.create(msg, worldId, worldSize); - // Register two ranks + // Prepare data int rankA = 1; int rankB = 2; - world.registerRank(rankA); - world.registerRank(rankB); - - // Prepare data MPI_Status status{}; std::vector messageDataAB = { 0, 1, 2 }; std::vector messageDataBA = { 3, 2, 1, 0 }; @@ -381,11 +387,8 @@ TEST_CASE("Test ring 
sendrecv", "[mpi]") scheduler::MpiWorld world; world.create(msg, worldId, worldSize); - // Register five processes (0 already registered) + // Use five processes std::vector ranks = { 0, 1, 2, 3, 4 }; - for (int i = 1; i < ranks.size(); i++) { - world.registerRank(ranks[i]); - } // Prepare data MPI_Status status{}; @@ -432,13 +435,9 @@ TEST_CASE("Test async send and recv", "[mpi]") scheduler::MpiWorld world; world.create(msg, worldId, worldSize); - // Register two ranks + // Send a couple of async messages (from both to each other) int rankA = 1; int rankB = 2; - world.registerRank(rankA); - world.registerRank(rankB); - - // Send a couple of async messages (from both to each other) std::vector messageDataA = { 0, 1, 2 }; std::vector messageDataB = { 3, 4, 5, 6 }; int sendIdA = world.isend( @@ -475,41 +474,61 @@ TEST_CASE("Test send across hosts", "[mpi]") server.start(); usleep(1000 * 100); + auto& sch = faabric::scheduler::getScheduler(); + + // Force the scheduler to initialise a world in the remote host by setting + // a worldSize bigger than the slots available locally + int worldSize = 2; + faabric::HostResources localResources; + localResources.set_slots(1); + localResources.set_usedslots(1); + faabric::HostResources otherResources; + otherResources.set_slots(1); + + // Set up a remote host + std::string otherHost = LOCALHOST; + sch.addHostToGlobalSet(otherHost); + + // Mock everything to make sure the other host has resources as well + faabric::util::setMockMode(true); + sch.setThisHostResources(localResources); + faabric::scheduler::queueResourceResponse(otherHost, otherResources); + // Set up the world on this host faabric::Message msg = faabric::util::messageFactory(user, func); msg.set_mpiworldid(worldId); msg.set_mpiworldsize(worldSize); + // Create the local world scheduler::MpiWorld& localWorld = - getMpiWorldRegistry().createWorld(msg, worldId, LOCALHOST); + getMpiWorldRegistry().createWorld(msg, worldId); - // Set up a world on the "remote" host - std::string otherHost = faabric::util::randomString(MPI_HOST_STATE_LEN - 3); scheduler::MpiWorld remoteWorld; remoteWorld.overrideHost(otherHost); - remoteWorld.initialiseFromState(msg, worldId); + remoteWorld.initialiseFromMsg(msg); // Register two ranks (one on each host) - int rankA = 1; - int rankB = 2; - remoteWorld.registerRank(rankA); - localWorld.registerRank(rankB); + int rankA = 0; + int rankB = 1; std::vector messageData = { 0, 1, 2 }; + // Undo the mocking, so we actually send the MPI message + faabric::util::setMockMode(false); + // Send a message that should get sent to this host remoteWorld.send( - rankA, rankB, BYTES(messageData.data()), MPI_INT, messageData.size()); + rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); usleep(1000 * 100); SECTION("Check queueing") { - REQUIRE(localWorld.getLocalQueueSize(rankA, rankB) == 1); + REQUIRE(localWorld.getLocalQueueSize(rankB, rankA) == 1); // Check message content faabric::MPIMessage actualMessage = - *(localWorld.getLocalQueue(rankA, rankB)->dequeue()); - checkMessage(actualMessage, rankA, rankB, messageData); + *(localWorld.getLocalQueue(rankB, rankA)->dequeue()); + checkMessage(actualMessage, rankB, rankA, messageData); } SECTION("Check recv") @@ -518,12 +537,12 @@ TEST_CASE("Test send across hosts", "[mpi]") MPI_Status status{}; auto buffer = new int[messageData.size()]; localWorld.recv( - rankA, rankB, BYTES(buffer), MPI_INT, messageData.size(), &status); + rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); std::vector 
actual(buffer, buffer + messageData.size()); REQUIRE(actual == messageData); - REQUIRE(status.MPI_SOURCE == rankA); + REQUIRE(status.MPI_SOURCE == rankB); REQUIRE(status.MPI_ERROR == MPI_SUCCESS); REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); } @@ -541,15 +560,8 @@ TEST_CASE("Test send/recv message with no data", "[mpi]") scheduler::MpiWorld world; world.create(msg, worldId, worldSize); - // Register two ranks int rankA1 = 1; int rankA2 = 2; - world.registerRank(rankA1); - world.registerRank(rankA2); - - // Check we know the number of state keys - state::State& state = state::getGlobalState(); - REQUIRE(state.getKVCount() == 4); // Send a message between colocated ranks std::vector messageData = { 0 }; @@ -562,9 +574,6 @@ TEST_CASE("Test send/recv message with no data", "[mpi]") *(world.getLocalQueue(rankA1, rankA2)->dequeue()); REQUIRE(actualMessage.count() == 0); REQUIRE(actualMessage.type() == FAABRIC_INT); - - // Check no extra data in state - REQUIRE(state.getKVCount() == 4); } SECTION("Check receiving with null ptr") @@ -573,8 +582,6 @@ TEST_CASE("Test send/recv message with no data", "[mpi]") MPI_Status status{}; world.recv(rankA1, rankA2, nullptr, MPI_INT, 0, &status); - // Check no extra data in state - REQUIRE(state.getKVCount() == 4); REQUIRE(status.MPI_SOURCE == rankA1); REQUIRE(status.MPI_ERROR == MPI_SUCCESS); REQUIRE(status.bytesSize == 0); @@ -591,9 +598,6 @@ TEST_CASE("Test recv with partial data", "[mpi]") scheduler::MpiWorld world; world.create(msg, worldId, worldSize); - world.registerRank(1); - world.registerRank(2); - // Send a message with size less than the recipient is expecting std::vector messageData = { 0, 1, 2, 3 }; unsigned long actualSize = messageData.size(); @@ -621,9 +625,6 @@ TEST_CASE("Test probe", "[mpi]") scheduler::MpiWorld world; world.create(msg, worldId, worldSize); - world.registerRank(1); - world.registerRank(2); - // Send two messages of different sizes std::vector messageData = { 0, 1, 2, 3, 4, 5, 6 }; unsigned long sizeA = 2; @@ -668,35 +669,44 @@ TEST_CASE("Test can't get in-memory queue for non-local ranks", "[mpi]") { cleanFaabric(); - std::string hostA = faabric::util::randomString(MPI_HOST_STATE_LEN - 5); - std::string hostB = faabric::util::randomString(MPI_HOST_STATE_LEN - 3); + std::string otherHost = LOCALHOST; - const faabric::Message& msg = faabric::util::messageFactory(user, func); + auto& sch = faabric::scheduler::getScheduler(); + + // Force the scheduler to initialise a world in the remote host by setting + // a worldSize bigger than the slots available locally + int worldSize = 4; + faabric::HostResources localResources; + localResources.set_slots(2); + localResources.set_usedslots(1); + faabric::HostResources otherResources; + otherResources.set_slots(2); + + // Set up a remote host + sch.addHostToGlobalSet(otherHost); + + // Mock everything to make sure the other host has resources as well + faabric::util::setMockMode(true); + sch.setThisHostResources(localResources); + faabric::scheduler::queueResourceResponse(otherHost, otherResources); + + faabric::Message msg = faabric::util::messageFactory(user, func); + msg.set_mpiworldsize(worldSize); scheduler::MpiWorld worldA; - worldA.overrideHost(hostA); worldA.create(msg, worldId, worldSize); scheduler::MpiWorld worldB; - worldB.overrideHost(hostB); - worldB.initialiseFromState(msg, worldId); - - // Register one rank on each host - int rankA = 1; - int rankB = 2; - worldA.registerRank(rankA); - worldB.registerRank(rankB); - - // Check we can't access unregistered 
rank on either - REQUIRE_THROWS(worldA.getLocalQueue(0, 3)); - REQUIRE_THROWS(worldB.getLocalQueue(0, 3)); + worldB.overrideHost(otherHost); + worldB.initialiseFromMsg(msg); // Check that we can't access rank on another host locally - REQUIRE_THROWS(worldA.getLocalQueue(0, rankB)); + REQUIRE_THROWS(worldA.getLocalQueue(0, 2)); // Double check even when we've retrieved the rank - REQUIRE(worldA.getHostForRank(rankB) == hostB); - REQUIRE_THROWS(worldA.getLocalQueue(0, rankB)); + REQUIRE(worldA.getHostForRank(2) == otherHost); + REQUIRE_THROWS(worldA.getLocalQueue(0, 2)); + faabric::util::setMockMode(false); tearDown({ &worldA, &worldB }); } @@ -715,22 +725,6 @@ TEST_CASE("Check sending to invalid rank", "[mpi]") tearDown({ &world }); } -TEST_CASE("Check sending to unregistered rank", "[mpi]") -{ - cleanFaabric(); - - const faabric::Message& msg = faabric::util::messageFactory(user, func); - scheduler::MpiWorld world; - world.create(msg, worldId, worldSize); - - // Rank hasn't yet been registered - int destRank = 2; - std::vector input = { 0, 1 }; - REQUIRE_THROWS(world.send(0, destRank, BYTES(input.data()), MPI_INT, 2)); - - tearDown({ &world }); -} - TEST_CASE("Test collective messaging locally and across hosts", "[mpi]") { cleanFaabric(); @@ -739,33 +733,46 @@ TEST_CASE("Test collective messaging locally and across hosts", "[mpi]") server.start(); usleep(1000 * 100); - std::string otherHost = "123.45.67.8"; + auto& sch = faabric::scheduler::getScheduler(); + // Here we rely on the scheduler running out of resources, and overloading + // the localWorld with ranks 4 and 5 int thisWorldSize = 6; + faabric::HostResources localResources; + localResources.set_slots(1); + localResources.set_usedslots(1); + faabric::HostResources otherResources; + otherResources.set_slots(3); + + // Set up a remote host + std::string otherHost = LOCALHOST; + sch.addHostToGlobalSet(otherHost); + + // Mock everything to make sure the other host has resources as well + faabric::util::setMockMode(true); + sch.setThisHostResources(localResources); + faabric::scheduler::queueResourceResponse(otherHost, otherResources); faabric::Message msg = faabric::util::messageFactory(user, func); msg.set_mpiworldid(worldId); msg.set_mpiworldsize(thisWorldSize); - MpiWorld& localWorld = - getMpiWorldRegistry().createWorld(msg, worldId, LOCALHOST); + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); scheduler::MpiWorld remoteWorld; - remoteWorld.initialiseFromState(msg, worldId); remoteWorld.overrideHost(otherHost); + remoteWorld.initialiseFromMsg(msg); + + // Unset mock mode to actually send remote MPI messages + faabric::util::setMockMode(false); // Register ranks on both hosts int remoteRankA = 1; int remoteRankB = 2; int remoteRankC = 3; - remoteWorld.registerRank(remoteRankA); - remoteWorld.registerRank(remoteRankB); - remoteWorld.registerRank(remoteRankC); int localRankA = 4; int localRankB = 5; - localWorld.registerRank(localRankA); - localWorld.registerRank(localRankB); // Note that ranks are deliberately out of order std::vector remoteWorldRanks = { remoteRankB, @@ -1051,16 +1058,13 @@ template void doReduceTest(scheduler::MpiWorld& world, TEST_CASE("Test reduce", "[mpi]") { + cleanFaabric(); + const faabric::Message& msg = faabric::util::messageFactory(user, func); scheduler::MpiWorld world; int thisWorldSize = 5; world.create(msg, worldId, thisWorldSize); - // Register the ranks (zero already registered by default - for (int r = 1; r < thisWorldSize; r++) { - world.registerRank(r); - } - // Prepare 
inputs int root = 3; @@ -1220,11 +1224,6 @@ TEST_CASE("Test operator reduce", "[mpi]") int thisWorldSize = 5; world.create(msg, worldId, thisWorldSize); - // Register the ranks - for (int r = 1; r < thisWorldSize; r++) { - world.registerRank(r); - } - SECTION("Max") { SECTION("Integers") @@ -1408,11 +1407,6 @@ TEST_CASE("Test gather and allgather", "[mpi]") world.create(msg, worldId, thisWorldSize); - // Register the ranks (zero already registered by default - for (int r = 1; r < thisWorldSize; r++) { - world.registerRank(r); - } - // Build up per-rank data and expectation int nPerRank = 3; int gatheredSize = nPerRank * thisWorldSize; @@ -1541,11 +1535,6 @@ TEST_CASE("Test scan", "[mpi]") int count = 3; world.create(msg, worldId, thisWorldSize); - // Register the ranks - for (int r = 1; r < thisWorldSize; r++) { - world.registerRank(r); - } - // Prepare input data std::vector> rankData(thisWorldSize, std::vector(count)); @@ -1605,11 +1594,6 @@ TEST_CASE("Test all-to-all", "[mpi]") int thisWorldSize = 4; world.create(msg, worldId, thisWorldSize); - // Register the ranks - for (int r = 1; r < thisWorldSize; r++) { - world.registerRank(r); - } - // Build inputs and expected int inputs[4][8] = { { 0, 1, 2, 3, 4, 5, 6, 7 }, @@ -1654,19 +1638,39 @@ TEST_CASE("Test all-to-all", "[mpi]") TEST_CASE("Test RMA across hosts", "[mpi]") { cleanFaabric(); - std::string otherHost = "192.168.9.2"; + + auto& sch = faabric::scheduler::getScheduler(); + + // Set up host resources + int worldSize = 5; + faabric::HostResources localResources; + localResources.set_slots(3); + localResources.set_usedslots(1); + faabric::HostResources otherResources; + otherResources.set_slots(2); + + // Set up a remote host + std::string otherHost = LOCALHOST; + sch.addHostToGlobalSet(otherHost); + + // Mock everything to make sure the other host has resources as well + faabric::util::setMockMode(true); + sch.setThisHostResources(localResources); + faabric::scheduler::queueResourceResponse(otherHost, otherResources); faabric::Message msg = faabric::util::messageFactory(user, func); msg.set_mpiworldid(worldId); msg.set_mpiworldsize(worldSize); - MpiWorldRegistry& registry = getMpiWorldRegistry(); scheduler::MpiWorld& localWorld = - registry.createWorld(msg, worldId, LOCALHOST); + getMpiWorldRegistry().createWorld(msg, worldId); scheduler::MpiWorld remoteWorld; remoteWorld.overrideHost(otherHost); - remoteWorld.initialiseFromState(msg, worldId); + remoteWorld.initialiseFromMsg(msg); + + // Undo the mocking, so we actually send remote MPI messages + faabric::util::setMockMode(false); FunctionCallServer server; server.start(); @@ -1674,13 +1678,7 @@ TEST_CASE("Test RMA across hosts", "[mpi]") // Register four ranks int rankA1 = 1; - int rankA2 = 2; int rankB1 = 3; - int rankB2 = 4; - localWorld.registerRank(rankA1); - localWorld.registerRank(rankA2); - remoteWorld.registerRank(rankB1); - remoteWorld.registerRank(rankB2); std::vector dataA1 = { 0, 1, 2, 3 }; int dataCount = (int)dataA1.size();
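The updated tests lean on one recurring mocking pattern to force the scheduler to place rank 1 on a "remote" host, so that MpiWorld::create() records a genuinely non-local entry in its rank-host map. A condensed sketch follows; it is illustrative only: the helper name forceRemoteRank is invented, slot counts vary per test, and the include paths are assumptions.

```cpp
#include <string>

#include <faabric/scheduler/Scheduler.h> // assumed include paths
#include <faabric/util/testing.h>

// Give the local host no free slots and advertise one slot on a mocked
// "remote" host, so the scheduler sends the chained call for rank 1 there.
void forceRemoteRank(faabric::scheduler::Scheduler& sch,
                     const std::string& otherHost)
{
    faabric::HostResources localResources;
    localResources.set_slots(1);
    localResources.set_usedslots(1); // no capacity left locally

    faabric::HostResources otherResources;
    otherResources.set_slots(1); // one free slot on the remote host

    sch.addHostToGlobalSet(otherHost);

    // With mock mode on, resource requests and chained calls are recorded
    // rather than sent over the network
    faabric::util::setMockMode(true);
    sch.setThisHostResources(localResources);
    faabric::scheduler::queueResourceResponse(otherHost, otherResources);

    // ... create the world(s) here, then call setMockMode(false) before
    // exchanging real MPI messages between the two world instances
}
```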