Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Drop single point model recovery. #6112

Closed
wants to merge 16 commits into from
Closed
14 changes: 0 additions & 14 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,6 @@ def TestPythonGPU(args) {
}
}

def TestCppRabit() {
node(nodeReq) {
unstash name: 'xgboost_rabit_tests'
unstash name: 'srcs'
echo "Test C++, rabit mock on"
def container_type = "cpu"
def docker_binary = "docker"
sh """
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/runxgb.sh xgboost tests/ci_build/approx.conf.in
"""
deleteDir()
}
}

def TestCppGPU(args) {
def nodeReq = 'linux && mgpu'
def artifact_cuda_version = (args.artifact_cuda_version) ?: ref_cuda_ver
Expand Down
2 changes: 1 addition & 1 deletion R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ PKG_LIBS = @OPENMP_CXXFLAGS@ @OPENMP_LIB@ @ENDIAN_FLAG@ @BACKTRACE_LIB@ -pthread
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
$(PKGROOT)/rabit/src/allreduce_base.o
2 changes: 1 addition & 1 deletion R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
$(PKGROOT)/rabit/src/allreduce_base.o

$(OBJECTS) : xgblib
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ object XGBoost extends Serializable {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val sc = trainingData.sparkContext
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
hasGroup, xgbExecParams.numWorkers)
Expand All @@ -595,6 +596,8 @@ object XGBoost extends Serializable {
xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers,
xgbExecParams.killSparkContextOnWorkerFailure)

tracker.getWorkerEnvs().putAll(xgbRabitParams)
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
Expand Down
5 changes: 3 additions & 2 deletions rabit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ cmake_minimum_required(VERSION 3.3)

find_package(Threads REQUIRED)

add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
add_library(rabit src/allreduce_base.cc src/engine.cc src/c_api.cc)
add_library(rabit_mock_static src/allreduce_base.cc src/engine_mock.cc src/c_api.cc)

target_link_libraries(rabit Threads::Threads dmlc)
target_link_libraries(rabit_mock_static Threads::Threads dmlc)

Expand Down
23 changes: 6 additions & 17 deletions rabit/include/rabit/internal/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <string>
#include <cstring>
#include <vector>
#include <chrono>
#include <unordered_map>
#include "utils.h"

Expand Down Expand Up @@ -95,18 +96,18 @@ namespace utils {
static constexpr int kInvalidSocket = -1;

template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, int timeout) {
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
#if defined(_WIN32)

#if IS_MINGW()
MingWError();
return -1;
#else
return WSAPoll(pfd, nfds, timeout);
return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
#endif // IS_MINGW()

#else
return poll(pfd, nfds, timeout);
return poll(pfd, nfds, std::chrono::milliseconds(timeout).count());
#endif // IS_MINGW()
}

Expand Down Expand Up @@ -616,32 +617,20 @@ struct PollHelper {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
}
/*!
* \brief wait for exception event on a single descriptor
* \param fd the file descriptor to wait the event for
* \param timeout the timeout counter, can be negative, which means wait until the event happen
* \return 1 if success, 0 if timeout, and -1 if error occurs
*/
inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
pollfd pfd;
pfd.fd = fd;
pfd.events = POLLPRI;
return PollImpl(&pfd, 1, timeout);
}

/*!
* \brief peform poll on the set defined, read, write, exception
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
* \return
*/
inline void Poll(long timeout = -1) { // NOLINT(*)
inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
std::vector<pollfd> fdset;
fdset.reserve(fds.size());
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == -1) {
if (ret <= 0) {
Socket::Error("Poll");
} else {
for (auto& pfd : fdset) {
Expand Down
31 changes: 0 additions & 31 deletions rabit/src/CMakeLists.txt

This file was deleted.

6 changes: 0 additions & 6 deletions rabit/src/README.md

This file was deleted.

33 changes: 24 additions & 9 deletions rabit/src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#define NOMINMAX
#include "rabit/base.h"
#include "rabit/internal/rabit-inl.h"
#include "allreduce_base.h"
#include <rabit/base.h>

#ifndef _WIN32
#include <netinet/tcp.h>
Expand Down Expand Up @@ -208,8 +209,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
rabit_timeout = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_timeout_sec")) {
timeout_sec = atoi(val);
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
timeout_sec = std::chrono::seconds(atoi(val));
utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second");
}
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
if (!strcmp(val, "true")) {
Expand Down Expand Up @@ -549,11 +550,12 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
// finish runing allreduce
if (finished) break;
// select must return
watcher.Poll();
watcher.Poll(timeout_sec);
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (watcher.CheckExcept(links[i].sock)) {
std::cout << __LINE__ << "kGetExcept" << std::endl;
return ReportError(&links[i], kGetExcept);
}
}
Expand All @@ -566,6 +568,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
&& links[i].size_read - size_up_reduce < eachreduce) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
std::cout << __LINE__ << ReportError(&links[i], ret).Message() << std::endl;
return ReportError(&links[i], ret);
}
}
Expand Down Expand Up @@ -623,6 +626,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
std::cout << __LINE__ << ReportError(&links[parent_index], ret).Message() << std::endl;
return ReportError(&links[parent_index], ret);
}
}
Expand All @@ -639,6 +643,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,

if (len == 0) {
links[parent_index].sock.Close();
std::cout << __LINE__ << ReportError(&links[parent_index], kRecvZeroLen).Message() << std::endl;
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {
Expand All @@ -655,6 +660,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
std::cout << __LINE__ << ReportError(&links[parent_index], ret).Message() << std::endl;
return ReportError(&links[parent_index], ret);
}
}
Expand All @@ -669,6 +675,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
if (i != parent_index && links[i].size_write < size_down_in) {
ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
if (ret != kSuccess) {
std::cout << __LINE__ << ReportError(&links[parent_index], ret).Message() << std::endl;
return ReportError(&links[i], ret);
}
}
Expand Down Expand Up @@ -729,7 +736,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
// finish running
if (finished) break;
// select
watcher.Poll();
watcher.Poll(timeout_sec);
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
Expand Down Expand Up @@ -819,7 +826,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
finished = false;
}
if (finished) break;
watcher.Poll();
watcher.Poll(timeout_sec);
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
Expand All @@ -831,7 +838,11 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
read_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) return ReportError(&next, ret);
if (ret != kSuccess) {
auto err = ReportError(&next, ret);
std::cout << __LINE__ << ": " << err.Message() << std::endl;
return err;
}
}
}
if (write_ptr < read_ptr && write_ptr != stop_write) {
Expand All @@ -845,7 +856,11 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) return ReportError(&prev, ret);
if (ret != kSuccess) {
auto err = ReportError(&prev, ret);
std::cout << __LINE__ << ": " << err.Message() << std::endl;
return err;
}
}
}
}
Expand Down Expand Up @@ -913,7 +928,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
finished = false;
}
if (finished) break;
watcher.Poll();
watcher.Poll(timeout_sec);
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
Expand Down
35 changes: 27 additions & 8 deletions rabit/src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_

#include <functional>
#include <future>
#include <vector>
#include <string>
#include <algorithm>
Expand All @@ -35,6 +37,7 @@ class Datatype {
}
namespace rabit {
namespace engine {

/*! \brief implementation of basic Allreduce engine */
class AllreduceBase : public IEngine {
public:
Expand Down Expand Up @@ -103,9 +106,11 @@ class AllreduceBase : public IEngine {
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
slice_begin, slice_end, size_prev_slice) == kSuccess,
if (world_size == 1 || world_size == -1) {
return;
}
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin,
slice_end, size_prev_slice) == kSuccess,
"AllgatherRing failed");
}
/*!
Expand All @@ -130,8 +135,8 @@ class AllreduceBase : public IEngine {
const char *_caller = _CALLER) override {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
kSuccess,
"Allreduce failed");
}
/*!
Expand Down Expand Up @@ -276,6 +281,20 @@ class AllreduceBase : public IEngine {
inline bool operator!=(const ReturnTypeEnum &v) const {
return value != v;
}
std::string Message() const {
switch (value) {
case kSuccess:
return "kSuccess";
case kConnReset:
return "kConnReset";
case kRecvZeroLen:
return "kRecvZeroLen";
case kSockError:
return "kSockError";
case kGetExcept:
return "kGetExcept";
}
}
};
/*! \brief translate errno to return type */
inline static ReturnType Errno2Return() {
Expand Down Expand Up @@ -518,9 +537,9 @@ class AllreduceBase : public IEngine {
//---- data structure related to model ----
// call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint
int seq_counter; // NOLINT
int seq_counter{0}; // NOLINT
// version number of model
int version_number; // NOLINT
int version_number {0}; // NOLINT
// whether the job is running in hadoop
bool hadoop_mode; // NOLINT
//---- local data related to link ----
Expand Down Expand Up @@ -571,7 +590,7 @@ class AllreduceBase : public IEngine {
// enable detailed logging
bool rabit_debug = false; // NOLINT
// by default, if rabit worker not recover in half an hour exit
int timeout_sec = 1800; // NOLINT
std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT
// flag to enable rabit_timeout
bool rabit_timeout = false; // NOLINT
// Enable TCP node delay
Expand Down
Loading