diff --git a/.gitignore b/.gitignore index 736c321b..eedb5b97 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,6 @@ mpich-3.2/ cmake-build-debug/ .vscode/ +# cmake +build/ + diff --git a/.travis.yml b/.travis.yml index 40af4327..f0a7a996 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,16 +8,29 @@ osx_image: xcode10.2 dist: xenial +language: cpp + # Use Build Matrix to do lint and build seperately env: matrix: - TASK=lint LINT_LANG=cpp - TASK=lint LINT_LANG=python - TASK=doc - - TASK=build + # - TASK=build - TASK=mpi-build - TASK=cmake-test +matrix: + exclude: + - os: osx + env: TASK=lint LINT_LANG=cpp + - os: osx + env: TASK=lint LINT_LANG=python + - os: osx + env: TASK=doc + - os: osx + env: TASK=build + # dependent apt packages addons: apt: diff --git a/CMakeLists.txt b/CMakeLists.txt index 758dbd4e..d668be3f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,11 +20,16 @@ if(R_LIB OR MINGW OR WIN32) CXX_STANDARD_REQUIRED ON POSITION_INDEPENDENT_CODE ON) else() - add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc) - add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc) + find_package(Threads REQUIRED) add_library(rabit_empty src/engine_empty.cc src/c_api.cc) + add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc) + + add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc) add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc) add_library(rabit_mock SHARED src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc) + target_link_libraries(rabit Threads::Threads) + target_link_libraries(rabit_mock_static Threads::Threads) + target_link_libraries(rabit_mock Threads::Threads) set(rabit_libs rabit rabit_base rabit_empty rabit_mock rabit_mock_static) set_target_properties(rabit rabit_base rabit_empty rabit_mock rabit_mock_static diff --git a/doc/guide.md b/doc/guide.md index 39a69e9e..7bf50b09 100644 --- a/doc/guide.md +++ b/doc/guide.md @@ -154,6 +154,8 @@ you can also refer to [wormhole](https://github.com/dmlc/wormhole/blob/master/le int main(int argc, char *argv[]) { ... rabit::Init(argc, argv); + // sync on expected model size before load checkpoint, if we pass rabit_bootstrap_cache=true + rabit::Allreduce(&model.size(), 1); // load the latest checked model int version = rabit::LoadCheckPoint(&model); // initialize the model if it is the first version @@ -370,3 +372,12 @@ Allreduce/Broadcast calls after the checkpoint from some alive nodes. This is just a conceptual introduction to rabit's fault tolerance model. The actual implementation is more sophisticated, and can deal with more complicated cases such as multiple nodes failure and node failure during recovery phase. + +Rabit Timeout +--------------- + +In certain cases, rabit cluster may suffer lack of resources to retry failed workers. +Thanks to fault tolerant assumption with infinite retry, it might cause entire cluster hang infinitely. +We introduce sidecar thread which runs when rabit fault tolerant runtime observed allreduce/broadcast errors. +By default, it will wait for 30 mins before all workers program exit. +User can opt-in this feature and change treshold by passing rabit_timeout=true and rabit_timeout_sec=x (in seconds). diff --git a/include/rabit/internal/utils.h b/include/rabit/internal/utils.h index 2df27b74..294f1018 100644 --- a/include/rabit/internal/utils.h +++ b/include/rabit/internal/utils.h @@ -7,6 +7,7 @@ #ifndef RABIT_INTERNAL_UTILS_H_ #define RABIT_INTERNAL_UTILS_H_ #define _CRT_SECURE_NO_WARNINGS +#include #include #include #include @@ -66,6 +67,11 @@ const int kPrintBuffer = 1 << 12; * co-locate in the same process */ extern bool STOP_PROCESS_ON_ERROR; +/* \brief parse config string too bool*/ +inline bool StringToBool(const char* s) { + return strcasecmp(s, "true") == 0 || atoi(s) != 0; +} + #ifndef RABIT_CUSTOMIZE_MSG_ /*! * \brief handling of Assert error, caused by inappropriate input @@ -86,7 +92,7 @@ inline void HandleAssertError(const char *msg) { */ inline void HandleCheckError(const char *msg) { if (STOP_PROCESS_ON_ERROR) { - fprintf(stderr, "%s, shutting down process", msg); + fprintf(stderr, "%s, shutting down process\n", msg); exit(-1); } else { fprintf(stderr, "%s, rabit is configured to keep process running\n", msg); diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 85a50642..19128224 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -25,8 +25,9 @@ if [ ${TASK} == "cmake-test" ]; then mkdir build cd build cmake -DRABIT_BUILD_TESTS=ON -DRABIT_BUILD_DMLC=ON -DGTEST_ROOT=${HOME}/.local .. - #unit tests - make + # known osx gtest 1.8 issue + cp ${HOME}/.local/lib/*.dylib . + make -j$(nproc) make test make install || exit -1 cd ../test diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 934142f6..0b8f2990 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -73,7 +73,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) { if (task_id == NULL) { task_id = getenv("mapreduce_task_id"); } - if (hadoop_mode != 0) { + if (hadoop_mode) { utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id"); } @@ -94,7 +94,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) { if (num_task == NULL) { num_task = getenv("mapreduce_job_maps"); } - if (hadoop_mode != 0) { + if (hadoop_mode) { utils::Check(num_task != NULL, "hadoop_mode is set but cannot find mapred_map_tasks"); } @@ -188,7 +188,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "DMLC_TASK_ID")) task_id = val; if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val; if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); - if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val); + if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val); if (!strcmp(name, "rabit_reduce_ring_mincount")) { reduce_ring_mincount = atoi(val); utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0"); @@ -209,10 +209,17 @@ void AllreduceBase::SetParam(const char *name, const char *val) { } } if (!strcmp(name, "rabit_bootstrap_cache")) { - rabit_bootstrap_cache = atoi(val); + rabit_bootstrap_cache = utils::StringToBool(val); } if (!strcmp(name, "rabit_debug")) { - rabit_debug = atoi(val); + rabit_debug = utils::StringToBool(val); + } + if (!strcmp(name, "rabit_timeout")) { + rabit_timeout = utils::StringToBool(val); + } + if (!strcmp(name, "rabit_timeout_sec")) { + timeout_sec = atoi(val); + utils::Assert(rabit_timeout > 0, "rabit_timeout_sec should be greater than 0 second"); } } /*! diff --git a/src/allreduce_base.h b/src/allreduce_base.h index d3413b93..00f9d2ef 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -496,7 +496,7 @@ class AllreduceBase : public IEngine { // version number of model int version_number; // whether the job is running in hadoop - int hadoop_mode; + bool hadoop_mode; //---- local data related to link ---- // index of parent link, can be -1, meaning this is root of the tree int parent_index; @@ -543,9 +543,13 @@ class AllreduceBase : public IEngine { // backdoor port int port = 0; // enable bootstrap cache 0 false 1 true - int rabit_bootstrap_cache = 0; + bool rabit_bootstrap_cache = false; // enable detailed logging - int rabit_debug = 0; + bool rabit_debug = false; + // by default, if rabit worker not recover in half an hour exit + int timeout_sec = 1800; + // flag to enable rabit_timeout + bool rabit_timeout = false; }; } // namespace engine } // namespace rabit diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index ec535483..428cc5c2 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -176,7 +176,7 @@ class AllreduceMock : public AllreduceRobust { if (mock_map.count(key) != 0) { num_trial += 1; // data processing frameworks runs on shared process - utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name); + _error("[%d]@@@Hit Mock Error:%s ", rank, name); } } }; diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index b9ddf71b..242faee9 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -8,6 +8,8 @@ #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE #define NOMINMAX +#include +#include #include #include #include "rabit/internal/io.h" @@ -19,6 +21,7 @@ namespace rabit { namespace engine { + AllreduceRobust::AllreduceRobust(void) { num_local_replica = 0; num_global_replica = 5; @@ -38,7 +41,7 @@ bool AllreduceRobust::Init(int argc, char* argv[]) { if (AllreduceBase::Init(argc, argv)) { // chenqin: alert user opted in experimental feature. if (rabit_bootstrap_cache) utils::HandleLogInfo( - "[EXPERIMENTAL] rabit bootstrap cache has been enabled\n"); + "[EXPERIMENTAL] bootstrap cache has been enabled\n"); checkpoint_loaded = false; if (num_global_replica == 0) { result_buffer_round = -1; @@ -55,24 +58,31 @@ bool AllreduceRobust::Shutdown(void) { try { // need to sync the exec before we shutdown, do a pesudo check point // execute checkpoint, note: when checkpoint existing, load will not happen - utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp, + _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check point must return true"); // reset result buffer resbuf.Clear(); seq_counter = 0; cachebuf.Clear(); cur_cache_seq = 0; lookupbuf.Clear(); // execute check ack step, load happens here - utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, + _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check ack must return true"); +// travis ci only osx test hang #if defined (__APPLE__) sleep(1); #endif + shutdown_timeout = true; + if (rabit_timeout_task.valid()) { + rabit_timeout_task.wait(); + _assert(rabit_timeout_task.get(), "expect timeout task return\n"); + } return AllreduceBase::Shutdown(); } catch (const std::exception& e) { fprintf(stderr, "%s\n", e.what()); return false; } } + /*! * \brief set parameters to the engine * \param name parameter name @@ -98,8 +108,8 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf, break; } } - utils::Assert(index == -1, "immutable cache key already exists"); - utils::Assert(type_nbytes*count > 0, "can't set empty cache"); + _assert(index == -1, "immutable cache key already exists"); + _assert(type_nbytes*count > 0, "can't set empty cache"); void* temp = cachebuf.AllocTemp(type_nbytes, count); cachebuf.PushTemp(cur_cache_seq, type_nbytes, count); std::memcpy(temp, buf, type_nbytes*count); @@ -133,9 +143,9 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, size_t siz = 0; void* temp = cachebuf.Query(index, &siz); - utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index"); - utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); - utils::Assert(siz > 0, "cache size should be greater than 0"); + _assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index"); + _assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); + _assert(siz > 0, "cache size should be greater than 0"); std::memcpy(buf, temp, type_nbytes*count); return 0; } @@ -317,7 +327,7 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model, local_rptr[local_chkpt_version][1]); local_model->Load(&fs); } else { - utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal); + _assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal); } } // reset result buffer @@ -327,14 +337,14 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model, if (global_checkpoint.length() == 0) { version_number = 0; } else { - utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, + _assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number"); global_model->Load(&fs); - utils::Assert(local_model == NULL || nlocal == num_local_replica + 1, + _assert(local_model == NULL || nlocal == num_local_replica + 1, "local model inconsistent, nlocal=%d", nlocal); } // run another phase of check ack, if recovered from data - utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, + _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true"); if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) { @@ -433,7 +443,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model, local_chkpt_version = !local_chkpt_version; } // execute checkpoint, note: when checkpoint existing, load will not happen - utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, + _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp, cur_cache_seq), "check point must return true"); // this is the critical region where we will change all the stored models @@ -460,7 +470,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model, // reset result buffer, mark boostrap phase complete resbuf.Clear(); seq_counter = 0; // execute check ack step, load happens here - utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, + _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true"); delta = utils::GetTime() - start; @@ -533,7 +543,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (all_links[i].size_read == 0) { int atmark = all_links[i].sock.AtMark(); if (atmark < 0) { - utils::Assert(all_links[i].sock.BadSocket(), "must already gone bad"); + _assert(all_links[i].sock.BadSocket(), "must already gone bad"); } else if (atmark > 0) { all_links[i].size_read = 1; } else { @@ -555,10 +565,10 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (len == 0) { all_links[i].sock.Close(); continue; } else if (len > 0) { - utils::Assert(oob_mark == kResetMark, "wrong oob msg"); - utils::Assert(all_links[i].sock.AtMark() != 1, "should already read past mark"); + _assert(oob_mark == kResetMark, "wrong oob msg"); + _assert(all_links[i].sock.AtMark() != 1, "should already read past mark"); } else { - utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // send out ack char ack = kResetAck; @@ -579,9 +589,9 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (len == 0) { all_links[i].sock.Close(); continue; } else if (len > 0) { - utils::Assert(ack == kResetAck, "wrong Ack MSG"); + _assert(ack == kResetAck, "wrong Ack MSG"); } else { - utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // set back to nonblock mode all_links[i].sock.SetNonBlock(true); @@ -600,14 +610,44 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { * \return true if err_type is kSuccess, false otherwise */ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { + shutdown_timeout = err_type == kSuccess; if (err_type == kSuccess) return true; - utils::Assert(err_link != NULL, "must know the error source"); - recover_counter += 1; + _assert(err_link != NULL, "must know the error link"); + recover_counter += 1; + // async launch timeout task if enable_rabit_timeout is set + if (rabit_timeout && !rabit_timeout_task.valid()) { + utils::Printf("[EXPERIMENTAL] timeout thread expires in %d second(s)\n", timeout_sec); + rabit_timeout_task = std::async(std::launch::async, [=]() { + if (rabit_debug) { + utils::Printf("[%d] timeout thread %ld starts\n", rank, + std::this_thread::get_id()); + } + int time = 0; + // check if rabit recovered every 100ms + while (time++ < 10 * timeout_sec) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (shutdown_timeout.load()) { + if (rabit_debug) { + utils::Printf("[%d] timeout task thread %ld exits\n", + rank, std::this_thread::get_id()); + } + return true; + } + } + // print on tracker to help debuging + TrackerPrint("[ERROR] rank " + std::to_string(rank) + "@"+ + host_uri + ":" +std::to_string(port) + " timeout\n"); + _error("[%d] exit due to time out %d s\n", rank, timeout_sec); + return false; + }); + } // simple way, shutdown all links for (size_t i = 0; i < all_links.size(); ++i) { if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); } + // smooth out traffic to tracker + std::this_thread::sleep_for(std::chrono::milliseconds(10*rank)); ReConnectLinks("recover"); return false; } @@ -724,8 +764,8 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role, // set p_req_in (*p_req_in)[i] = (req_in[i] != 0); if (req_out[i] != 0) { - utils::Assert(req_in[i] == 0, "cannot get and receive request"); - utils::Assert(static_cast(i) == best_link, "request result inconsistent"); + _assert(req_in[i] == 0, "cannot get and receive request"); + _assert(static_cast(i) == best_link, "request result inconsistent"); } } *p_recvlink = best_link; @@ -755,20 +795,20 @@ AllreduceRobust::TryRecoverData(RecoverType role, RefLinkVector &links = tree_links; // no need to run recovery for zero size messages if (links.size() == 0 || size == 0) return kSuccess; - utils::Assert(req_in.size() == links.size(), "TryRecoverData"); + _assert(req_in.size() == links.size(), "TryRecoverData"); const int nlink = static_cast(links.size()); { bool req_data = role == kRequestData; for (int i = 0; i < nlink; ++i) { if (req_in[i]) { - utils::Assert(i != recv_link, "TryDecideRouting"); + _assert(i != recv_link, "TryDecideRouting"); req_data = true; } } // do not need to provide data or receive data, directly exit if (!req_data) return kSuccess; } - utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); + _assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); if (role == kPassData) { links[recv_link].InitBuffer(1, size, reduce_buffer_size); } @@ -835,7 +875,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, for (int i = 0; i < nlink; ++i) { if (req_in[i]) min_write = std::min(links[i].size_write, min_write); } - utils::Assert(min_write <= links[pid].size_read, "boundary check"); + _assert(min_write <= links[pid].size_read, "boundary check"); ReturnType ret = links[pid].ReadToRingBuffer(min_write, size); if (ret != kSuccess) { return ReportError(&links[pid], ret); @@ -869,7 +909,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester, const int min_seq, const int max_seq) { // clear requester and rebuild from those with most cache entries if (requester) { - utils::Assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries"); + _assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries"); cachebuf.Clear(); lookupbuf.Clear(); cur_cache_seq = 0; @@ -998,7 +1038,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re int new_version = !local_chkpt_version; int nlocal = std::max(static_cast(local_rptr[new_version].size()) - 1, 0); // if we goes to this place, use must have already setup the state once - utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1, + _assert(nlocal == 1 || nlocal == num_local_replica + 1, "TryGetResult::Checkpoint"); return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]); } @@ -1048,13 +1088,13 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, // kLoadBootstrapCache should be treated similar as allreduce // when loadcheck/check/checkack runs in other nodes if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) { - utils::Assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations"); + _assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations"); } std::string msg = std::string(caller) + " pass negative seqno " + std::to_string(seqno) + " flag " + std::to_string(flag) + " version " + std::to_string(version_number); - utils::Assert(seqno >=0, msg.c_str()); + _assert(seqno >=0, msg.c_str()); ActionSummary req(flag, flag, seqno, cache_seqno); @@ -1068,7 +1108,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, if (act.check_ack()) { if (act.check_point()) { // if we also have check_point, do check point first - utils::Assert(!act.diff_seq(), + _assert(!act.diff_seq(), "check ack & check pt cannot occur together with normal ops"); // if we requested checkpoint, we are free to go if (req.check_point()) return true; @@ -1087,7 +1127,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, } else { if (act.check_point()) { if (act.diff_seq()) { - utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug"); + _assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug"); // print checkpoint consensus flag if user turn on debug if (rabit_debug) { req.print_flags(rank, "checkpoint req"); @@ -1112,16 +1152,16 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, if (!act.load_cache()) { if (act.seqno() > 0) { if (!requester) { - utils::Assert(req.check_point(), "checkpoint node should be KHaveData role"); + _assert(req.check_point(), "checkpoint node should be KHaveData role"); buf = resbuf.Query(act.seqno(), &size); - utils::Assert(buf != NULL, "buf should have data from resbuf"); - utils::Assert(size > 0, "buf size should be greater than 0"); + _assert(buf != NULL, "buf should have data from resbuf"); + _assert(size > 0, "buf size should be greater than 0"); } if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue; } } else { // cache seq no should be smaller than kSpecialOp - utils::Assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp, + _assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp, "checkpoint with kSpecialOp"); int max_cache_seq = cur_cache_seq; if (TryAllreduce(&max_cache_seq, sizeof(max_cache_seq), 1, @@ -1153,11 +1193,11 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, act.print_flags(rank, "loadcache act"); } // load cache should not running in parralel with other states - utils::Assert(!act.load_check(), + _assert(!act.load_check(), "load cache state expect no nodes doing load checkpoint"); - utils::Assert(!act.check_point() , + _assert(!act.check_point() , "load cache state expect no nodes doing checkpoint"); - utils::Assert(!act.check_ack(), + _assert(!act.check_ack(), "load cache state expect no nodes doing checkpoint ack"); // if all nodes are requester in load cache, skip @@ -1176,10 +1216,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, } // assert no req with load cache set goes into seq catch up - utils::Assert(!req.load_cache(), "load cache not interacte with rest states"); + _assert(!req.load_cache(), "load cache not interacte with rest states"); // no special flags, no checkpoint, check ack, load_check - utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug"); + _assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug"); if (act.diff_seq()) { bool requester = req.seqno() == act.seqno(); if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue; @@ -1194,7 +1234,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, // something is still incomplete try next round } } - utils::Assert(false, "RecoverExec: should not reach here"); + _assert(false, "RecoverExec: should not reach here"); return true; } /*! @@ -1222,13 +1262,13 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, std::string &chkpt = *p_local_chkpt; if (rptr.size() == 0) { rptr.push_back(0); - utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent"); + _assert(chkpt.length() == 0, "local chkpt space inconsistent"); } const int n = num_local_replica; { // backward passing, passing state in backward direction of the ring const int nlocal = static_cast(rptr.size() - 1); - utils::Assert(nlocal <= n + 1, "invalid local replica"); + _assert(nlocal <= n + 1, "invalid local replica"); std::vector msg_back(n + 1); msg_back[0] = nlocal; // backward passing one hop the request @@ -1282,7 +1322,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, { // forward passing, passing state in forward direction of the ring const int nlocal = static_cast(rptr.size() - 1); - utils::Assert(nlocal <= n + 1, "invalid local replica"); + _assert(nlocal <= n + 1, "invalid local replica"); std::vector msg_forward(n + 1); msg_forward[0] = nlocal; // backward passing one hop the request @@ -1367,7 +1407,7 @@ AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, if (num_local_replica == 0) return kSuccess; std::vector &rptr = *p_local_rptr; std::string &chkpt = *p_local_chkpt; - utils::Assert(rptr.size() == 2, + _assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state"); const int n = num_local_replica; std::vector sizes(n + 1); @@ -1423,10 +1463,10 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, LinkRecord *read_link, LinkRecord *write_link) { if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess; - utils::Assert(write_end <= read_end, + _assert(write_end <= read_end, "RingPassing: boundary check1"); - utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2"); - utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3"); + _assert(read_ptr <= read_end, "RingPassing: boundary check2"); + _assert(write_ptr <= write_end, "RingPassing: boundary check3"); // take reference LinkRecord &prev = *read_link, &next = *write_link; // send recv buffer diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 7704a31c..0e2fed40 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -10,6 +10,7 @@ */ #ifndef RABIT_ALLREDUCE_ROBUST_H_ #define RABIT_ALLREDUCE_ROBUST_H_ +#include #include #include #include @@ -632,6 +633,14 @@ o * the input state must exactly one saved state(local state of current node) int local_chkpt_version; // if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache bool checkpoint_loaded; + // sidecar executing timeout task + std::future rabit_timeout_task; + // flag to shutdown rabit_timeout_task before timeout + std::atomic shutdown_timeout{false}; + // error handler + void (* _error)(const char *fmt, ...) = utils::Error; + // assert handler + void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert; }; } // namespace engine } // namespace rabit diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index 987f4c01..e7c15bc3 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -2,14 +2,13 @@ find_package(GTest REQUIRED) add_executable( unit_tests - allreduce_base_test.cpp - allreduce_mock_test.cpp + allreduce_base_test.cc allreduce_robust_test.cc allreduce_mock_test.cc test_main.cpp) target_link_libraries( - unit_tests + unit_tests PRIVATE GTest::GTest GTest::Main - rabit_base rabit_mock) + rabit_base rabit_mock rabit) target_include_directories(unit_tests PUBLIC "$" diff --git a/test/cpp/allreduce_base_test.cpp b/test/cpp/allreduce_base_test.cc similarity index 100% rename from test/cpp/allreduce_base_test.cpp rename to test/cpp/allreduce_base_test.cc diff --git a/test/cpp/allreduce_mock_test.cpp b/test/cpp/allreduce_mock_test.cc similarity index 100% rename from test/cpp/allreduce_mock_test.cpp rename to test/cpp/allreduce_mock_test.cc diff --git a/test/cpp/allreduce_robust_test.cc b/test/cpp/allreduce_robust_test.cc new file mode 100644 index 00000000..c86f7627 --- /dev/null +++ b/test/cpp/allreduce_robust_test.cc @@ -0,0 +1,233 @@ +#define RABIT_CXXTESTDEFS_H +#include + +#include +#include +#include +#include "../../src/allreduce_robust.h" + +inline void mockerr(const char *fmt, ...) {EXPECT_STRCASEEQ(fmt, "[%d] exit due to time out %d s\n");} +inline void mockassert(bool val, const char *fmt, ...) {} +rabit::engine::AllreduceRobust::ReturnType err_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSockError); +rabit::engine::AllreduceRobust::ReturnType succ_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSuccess); + +TEST(allreduce_robust, sync_error_timeout) +{ + rabit::engine::AllreduceRobust m; + + std::string rabit_timeout = "rabit_timeout=1"; + char cmd[rabit_timeout.size()+1]; + std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd); + cmd[rabit_timeout.size()] = '\0'; + + std::string rabit_timeout_sec = "rabit_timeout_sec=1"; + char cmd1[rabit_timeout_sec.size()+1]; + std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1); + cmd1[rabit_timeout_sec.size()] = '\0'; + + char* argv[] = {cmd,cmd1}; + m.Init(2, argv); + m.rank = 0; + m.rabit_bootstrap_cache = 1; + m._error = mockerr; + m._assert = mockassert; + EXPECT_EQ(m.CheckAndRecover(err_type), false); + std::this_thread::sleep_for(std::chrono::milliseconds(1500)); + EXPECT_EQ(m.rabit_timeout_task.get(), false); +} + +TEST(allreduce_robust, sync_error_reset) +{ + rabit::engine::AllreduceRobust m; + + std::string rabit_timeout = "rabit_timeout=1"; + char cmd[rabit_timeout.size()+1]; + std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd); + cmd[rabit_timeout.size()] = '\0'; + + std::string rabit_timeout_sec = "rabit_timeout_sec=1"; + char cmd1[rabit_timeout_sec.size()+1]; + std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1); + cmd1[rabit_timeout_sec.size()] = '\0'; + + std::string rabit_debug = "rabit_debug=1"; + char cmd2[rabit_debug.size()+1]; + std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2); + cmd2[rabit_debug.size()] = '\0'; + + char* argv[] = {cmd, cmd1,cmd2}; + m.Init(3, argv); + m.rank = 0; + m._assert = mockassert; + EXPECT_EQ(m.CheckAndRecover(err_type), false); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + EXPECT_EQ(m.CheckAndRecover(succ_type), true); + EXPECT_EQ(m.rabit_timeout_task.get(), true); + m.Shutdown(); +} + +TEST(allreduce_robust, sync_success_error_timeout) +{ + rabit::engine::AllreduceRobust m; + + std::string rabit_timeout = "rabit_timeout=1"; + char cmd[rabit_timeout.size()+1]; + std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd); + cmd[rabit_timeout.size()] = '\0'; + + std::string rabit_timeout_sec = "rabit_timeout_sec=1"; + char cmd1[rabit_timeout_sec.size()+1]; + std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1); + cmd1[rabit_timeout_sec.size()] = '\0'; + + std::string rabit_debug = "rabit_debug=1"; + char cmd2[rabit_debug.size()+1]; + std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2); + cmd2[rabit_debug.size()] = '\0'; + + char* argv[] = {cmd, cmd1,cmd2}; + m.Init(3, argv); + m.rank = 0; + m.rabit_bootstrap_cache = 1; + m._assert = mockassert; + m._error = mockerr; + EXPECT_EQ(m.CheckAndRecover(succ_type), true); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + EXPECT_EQ(m.CheckAndRecover(err_type), false); + std::this_thread::sleep_for(std::chrono::milliseconds(1500)); + EXPECT_EQ(m.rabit_timeout_task.get(), false); +} + +TEST(allreduce_robust, sync_success_error_success) +{ + rabit::engine::AllreduceRobust m; + + std::string rabit_timeout = "rabit_timeout=1"; + char cmd[rabit_timeout.size()+1]; + std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd); + cmd[rabit_timeout.size()] = '\0'; + + std::string rabit_timeout_sec = "rabit_timeout_sec=1"; + char cmd1[rabit_timeout_sec.size()+1]; + std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1); + cmd1[rabit_timeout_sec.size()] = '\0'; + + std::string rabit_debug = "rabit_debug=1"; + char cmd2[rabit_debug.size()+1]; + std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2); + cmd2[rabit_debug.size()] = '\0'; + + char* argv[] = {cmd, cmd1,cmd2}; + m.Init(3, argv); + m.rank = 0; + m.rabit_bootstrap_cache = 1; + m._assert = mockassert; + EXPECT_EQ(m.CheckAndRecover(succ_type), true); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + EXPECT_EQ(m.CheckAndRecover(err_type), false); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + EXPECT_EQ(m.CheckAndRecover(succ_type), true); + std::this_thread::sleep_for(std::chrono::milliseconds(1100)); + EXPECT_EQ(m.rabit_timeout_task.get(), true); + m.Shutdown(); +} + +TEST(allreduce_robust, sync_error_no_reset_timeout) +{ + rabit::engine::AllreduceRobust m; + + std::string rabit_timeout = "rabit_timeout=1"; + char cmd[rabit_timeout.size()+1]; + std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd); + cmd[rabit_timeout.size()] = '\0'; + + std::string rabit_timeout_sec = "rabit_timeout_sec=1"; + char cmd1[rabit_timeout_sec.size()+1]; + std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1); + cmd1[rabit_timeout_sec.size()] = '\0'; + + std::string rabit_debug = "rabit_debug=1"; + char cmd2[rabit_debug.size()+1]; + std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2); + cmd2[rabit_debug.size()] = '\0'; + + char* argv[] = {cmd, cmd1,cmd2}; + m.Init(3, argv); + m.rank = 0; + m.rabit_bootstrap_cache = 1; + m._assert = mockassert; + m._error = mockerr; + auto start = std::chrono::system_clock::now(); + + EXPECT_EQ(m.CheckAndRecover(err_type), false); + std::this_thread::sleep_for(std::chrono::milliseconds(1100)); + + EXPECT_EQ(m.CheckAndRecover(err_type), false); + + m.rabit_timeout_task.wait(); + auto end = std::chrono::system_clock::now(); + std::chrono::duration diff = end-start; + + EXPECT_EQ(m.rabit_timeout_task.get(), false); + // expect second error don't overwrite/reset timeout task + EXPECT_LT(diff.count(), 2); +} + +TEST(allreduce_robust, no_timeout_shut_down) +{ + rabit::engine::AllreduceRobust m; + + std::string rabit_timeout = "rabit_timeout=1"; + char cmd[rabit_timeout.size()+1]; + std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd); + cmd[rabit_timeout.size()] = '\0'; + + std::string rabit_timeout_sec = "rabit_timeout_sec=1"; + char cmd1[rabit_timeout_sec.size()+1]; + std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1); + cmd1[rabit_timeout_sec.size()] = '\0'; + + std::string rabit_debug = "rabit_debug=1"; + char cmd2[rabit_debug.size()+1]; + std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2); + cmd2[rabit_debug.size()] = '\0'; + + char* argv[] = {cmd, cmd1,cmd2}; + m.Init(3, argv); + m.rank = 0; + + EXPECT_EQ(m.CheckAndRecover(succ_type), true); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + m.Shutdown(); +} + +TEST(allreduce_robust, shut_down_before_timeout) +{ + rabit::engine::AllreduceRobust m; + + std::string rabit_timeout = "rabit_timeout=1"; + char cmd[rabit_timeout.size()+1]; + std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd); + cmd[rabit_timeout.size()] = '\0'; + + std::string rabit_timeout_sec = "rabit_timeout_sec=1"; + char cmd1[rabit_timeout_sec.size()+1]; + std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1); + cmd1[rabit_timeout_sec.size()] = '\0'; + + std::string rabit_debug = "rabit_debug=1"; + char cmd2[rabit_debug.size()+1]; + std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2); + cmd2[rabit_debug.size()] = '\0'; + + char* argv[] = {cmd, cmd1,cmd2}; + m.Init(3, argv); + m.rank = 0; + rabit::engine::AllreduceRobust::LinkRecord a; + m.err_link = &a; + + EXPECT_EQ(m.CheckAndRecover(err_type), false); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + m.Shutdown(); +} \ No newline at end of file diff --git a/test/cpp/test_main.cpp b/test/cpp/test_main.cpp index 08fb8390..6eb025ac 100644 --- a/test/cpp/test_main.cpp +++ b/test/cpp/test_main.cpp @@ -3,5 +3,6 @@ int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); } diff --git a/test/test.mk b/test/test.mk index e6706daf..1028ec50 100644 --- a/test/test.mk +++ b/test/test.mk @@ -13,7 +13,7 @@ all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_di # this experiment test recovery with actually process exit, use keepalive to keep program alive model_recover_10_10k: - $(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=-1 rabit_debug=1 rabit_reduce_ring_mincount=1 + $(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=true rabit_debug=true rabit_reduce_ring_mincount=1 rabit_timeout=true rabit_timeout_sec=5 model_recover_10_10k_die_same: $(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 rabit_bootstrap_cache=1