Skip to content

Commit

Permalink
Revamp the rabit implementation. (#10112)
Browse files Browse the repository at this point in the history
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
  • Loading branch information
trivialfis authored May 20, 2024
1 parent ba9b4cb commit a5a5810
Show file tree
Hide file tree
Showing 195 changed files with 2,750 additions and 9,216 deletions.
4 changes: 0 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF)
option(USE_DEVICE_DEBUG "Generate CUDA device debug info." OFF)
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
option(RABIT_MOCK "Build rabit with mock" OFF)
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
## CUDA
Expand Down Expand Up @@ -282,9 +281,6 @@ if(MSVC)
endif()
endif()

# rabit
add_subdirectory(rabit)

# core xgboost
add_subdirectory(${xgboost_SOURCE_DIR}/src)
target_link_libraries(objxgboost PUBLIC dmlc)
Expand Down
8 changes: 1 addition & 7 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
Expand All @@ -134,7 +131,4 @@ OBJECTS= \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
$(PKGROOT)/src/c_api/c_api_error.o \
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o \
$(PKGROOT)/rabit/src/rabit_c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o
$(PKGROOT)/amalgamation/dmlc-minimum0.o
8 changes: 1 addition & 7 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
Expand All @@ -134,7 +131,4 @@ OBJECTS= \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
$(PKGROOT)/src/c_api/c_api_error.o \
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o \
$(PKGROOT)/rabit/src/rabit_c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o
$(PKGROOT)/amalgamation/dmlc-minimum0.o
1 change: 1 addition & 0 deletions cmake/Utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ function(xgboost_set_cuda_flags target)
target_include_directories(
${target} PRIVATE
${xgboost_SOURCE_DIR}/gputreeshap
${xgboost_SOURCE_DIR}/rabit/include
${CUDAToolkit_INCLUDE_DIRS})

if(MSVC)
Expand Down
2 changes: 1 addition & 1 deletion demo/dask/cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def main(client: Client) -> None:
m = 100000
n = 100
rng = da.random.default_rng(1)
X = rng.normal(size=(m, n))
X = rng.normal(size=(m, n), chunks=(10000, -1))
y = X.sum(axis=1)

# DaskDMatrix acts like normal DMatrix, works as a proxy for local
Expand Down
184 changes: 105 additions & 79 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
*
* @return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
char const *c_json_config, DMatrixHandle m,
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values,
char const *config, DMatrixHandle m,
bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);

Expand Down Expand Up @@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
*
* @brief Experimental support for exposing internal communicator in XGBoost.
*
* @note This is still under development.
*
* The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
* changed significantly since its adoption. It consists of a tracker and a set of
* workers. The tracker is responsible for bootstrapping the communication group and
* handling centralized tasks like logging. The workers are actual communicators
* performing collective tasks like allreduce.
*
* To use the collective implementation, one needs to first create a tracker with
* corresponding parameters, then get the arguments for workers using
* XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
* XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a
* XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
* `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
* runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
* runtime is shutting down. This requirement is similar to a Python thread or socket,
* which should not be relied upon in a `__del__` function.
*
* Since it's used as a part of XGBoost, errors will be returned when a XGBoost function
* is called, for instance, training a booster might return a connection error.
*
* @{
*/

/**
* @brief Handle to tracker.
* @brief Handle to the tracker.
*
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
* other one is `federated`.
* other one is `federated`. `rabit` is used for normal collective communication, while
* `federated` is used for federated learning.
*
* This is still under development.
*/
typedef void *TrackerHandle; /* NOLINT */

Expand All @@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
*
* @param config JSON encoded parameters.
*
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
* and `federated`.
* - dmlc_communicator: String, the type of tracker to create. Available options are
* `rabit` and `federated`. See @ref TrackerHandle for more info.
* - n_workers: Integer, the number of workers.
* - port: (Optional) Integer, the port this tracker should listen to.
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
* - timeout: (Optional) Integer, timeout in seconds for various networking
operations. Default is 300 seconds.
*
* Some configurations are `rabit` specific:
*
* - host: (Optional) String, used by the `rabit` tracker to specify the address of the host.
* This can be useful when the communicator cannot reliably obtain the host address.
* - sortby: (Optional) Integer.
* + 0: Sort workers by their host name.
* + 1: Sort workers by task IDs.
*
* Some `federated` specific configurations:
* - federated_secure: Boolean, whether this is a secure server.
* - federated_secure: Boolean, whether this is a secure server. False for testing.
* - server_key_path: Path to the server key. Used only if this is a secure server.
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
Expand Down Expand Up @@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
*/
XGB_DLL int XGTrackerFree(TrackerHandle handle);

/*!
* \brief Initialize the collective communicator.
/**
* @brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
* Call this once in the worker process before using anything. Please make sure
* XGCommunicatorFinalize() is called after use. The initialized communicator is a global
* thread-local variable.
*
* \param config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* @param config JSON encoded configuration. Accepted JSON keys are:
* - dmlc_communicator: The type of the communicator, this should match the tracker type.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
* `USE_DLOPEN_NCCL`.
* Only applicable to the Federated communicator (use upper case for environment variables, use
*
* Only applicable to the `rabit` communicator:
* - dmlc_tracker_uri: Hostname or IP address of the tracker.
* - dmlc_tracker_port: Port number of the tracker.
* - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - dmlc_retry: The number of retries for connection failure.
* - dmlc_timeout: Timeout in seconds.
* - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
*
* Only applicable to the `federated` communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
* \return 0 for success, -1 for failure.
* - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key_path: Client key file path. Only needed for the SSL mode.
* - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorInit(char const* config);

/*!
* \brief Finalize the collective communicator.
/**
* @brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
* Call this function after you have finished all jobs.
*
* \return 0 for success, -1 for failure.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorFinalize(void);

/*!
* \brief Get rank of current process.
/**
* @brief Get rank of the current process.
*
* \return Rank of the worker.
* @return Rank of the worker.
*/
XGB_DLL int XGCommunicatorGetRank(void);

/*!
* \brief Get total number of processes.
/**
* @brief Get the total number of processes.
*
* \return Total world size.
* @return Total world size.
*/
XGB_DLL int XGCommunicatorGetWorldSize(void);

/*!
* \brief Get if the communicator is distributed.
/**
* @brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
* @return True if the communicator is distributed.
*/
XGB_DLL int XGCommunicatorIsDistributed(void);

/*!
* \brief Print the message to the communicator.
/**
* @brief Print the message to the tracker.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
* This function can be used to communicate the information of the progress to the user
* who monitors the tracker.
*
* \param message The message to be printed.
* \return 0 for success, -1 for failure.
* @param message The message to be printed.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorPrint(char const *message);

/*!
* \brief Get the name of the processor.
/**
* @brief Get the name of the processor.
*
* \param name_str Pointer that receives the returned processor name.
* \return 0 for success, -1 for failure.
* @param name_str Pointer that receives the returned processor name.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);

/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
/**
* @brief Broadcast a memory region to all others from root. This function is NOT
* thread-safe.
*
* Example:
* \code
* @code
* int a = 1;
* Broadcast(&a, sizeof(a), root);
* \endcode
* @endcode
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Pointer to the send or receive buffer.
* @param size Size of the data in bytes.
* @param root The process rank to broadcast from.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);

/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
/**
* @brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* \code
* vector<int> data(10);
* @code
* enum class Op {
* kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
* };
* std::vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType::kInt32, Op::kSum);
* Allreduce(data.data(), data.size(), DataType::kInt32, Op::kSum);
* ...
* \endcode
* @endcode
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Buffer for both sending and receiving data.
* @param count Number of elements to be reduced.
* @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);

Expand Down
5 changes: 2 additions & 3 deletions include/xgboost/collective/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ struct ResultImpl {
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
#define __builtin_FILE() nullptr
#define __builtin_LINE() (-1)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
#endif

std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
} // namespace detail

/**
Expand Down
Loading

0 comments on commit a5a5810

Please sign in to comment.