Refactor device communicator to make allreduce more flexible (#9295)
rongou authored Jun 13, 2023
1 parent c2f0486 commit e70810b
Showing 11 changed files with 190 additions and 106 deletions.
81 changes: 81 additions & 0 deletions src/collective/communicator-inl.cuh
@@ -0,0 +1,81 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#pragma once
#include <string>
#include <vector>

#include "communicator.h"
#include "device_communicator.cuh"

namespace xgboost {
namespace collective {

/**
* @brief Reduce values from all processes and distribute the result back to all processes.
* @param device ID of the device.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
template <Operation op>
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}

template <Operation op>
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}

template <Operation op>
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}

/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.
* @param send_buffer Buffer storing the input data.
* @param length_bytes Length in bytes of the input data.
* @param segments Size of each segment.
* @param receive_buffer Buffer storing the output data.
*/
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
}

/**
* @brief Synchronize device operations.
* @param device ID of the device.
*/
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }

} // namespace collective
} // namespace xgboost
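
For orientation, a minimal usage sketch of the new typed wrappers above (not part of the commit; the function and buffer names are illustrative, but the calls mirror the updated call sites in quantile.cu, auc.cu, and fit_stump.cu further down):

// Hypothetical call site: #include "communicator-inl.cuh", then sum a device-resident
// buffer of doubles across all workers. Overload resolution picks DataType::kDouble from
// the pointer type; the reduction is chosen at compile time via the Operation template
// parameter.
void SumAcrossWorkers(int device, double *d_sums, std::size_t n) {
  xgboost::collective::AllReduce<xgboost::collective::Operation::kSum>(device, d_sums, n);
  xgboost::collective::Synchronize(device);
}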
29 changes: 6 additions & 23 deletions src/collective/device_communicator.cuh
@@ -17,32 +17,15 @@ class DeviceCommunicator {
virtual ~DeviceCommunicator() = default;

/**
- * @brief Sum values from all processes and distribute the result back to all processes.
+ * @brief Combines values from all processes and distributes the result back to all processes.
+ *
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
+ * @param data_type Data type stored in the buffer.
+ * @param op The operation to perform.
*/
- virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
-
- /**
- * @brief Sum values from all processes and distribute the result back to all processes.
- * @param send_receive_buffer Buffer storing the data.
- * @param count Number of elements in the buffer.
- */
- virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
-
- /**
- * @brief Sum values from all processes and distribute the result back to all processes.
- * @param send_receive_buffer Buffer storing the data.
- * @param count Number of elements in the buffer.
- */
- virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
-
- /**
- * @brief Sum values from all processes and distribute the result back to all processes.
- * @param send_receive_buffer Buffer storing the data.
- * @param count Number of elements in the buffer.
- */
- virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
+ virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
+ Operation op) = 0;

/**
* @brief Gather variable-length values from all processes.
38 changes: 11 additions & 27 deletions src/collective/device_communicator_adapter.cuh
@@ -23,20 +23,18 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {

~DeviceCommunicatorAdapter() override = default;

- void AllReduceSum(float *send_receive_buffer, size_t count) override {
- DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
- }
-
- void AllReduceSum(double *send_receive_buffer, size_t count) override {
- DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
- }
-
- void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
- DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
- }
+ void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
+ Operation op) override {
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }

- void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
- DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
+ dh::safe_cuda(cudaSetDevice(device_ordinal_));
+ auto size = count * GetTypeSize(data_type);
+ host_buffer_.reserve(size);
+ dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
+ communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
+ dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
@@ -77,20 +75,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
}

private:
- template <collective::DataType data_type, typename T>
- void DoAllReduceSum(T *send_receive_buffer, size_t count) {
- if (communicator_->GetWorldSize() == 1) {
- return;
- }
-
- dh::safe_cuda(cudaSetDevice(device_ordinal_));
- auto size = count * sizeof(T);
- host_buffer_.reserve(size);
- dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
- communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
- dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
- }
-
int const device_ordinal_;
Communicator *communicator_;
/// Host buffer used to call communicator functions.
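
Both the adapter above and the NCCL backend below now override the single type-erased virtual declared in device_communicator.cuh, so the same call works against either one. A hedged sketch of driving that virtual directly (illustrative only; regular code goes through the communicator-inl.cuh wrappers instead):

// Illustrative: invoke the type-erased entry point on whichever backend is active.
void SumViaVirtual(xgboost::collective::DeviceCommunicator *comm, std::int64_t *d_buf,
                   std::size_t n) {
  comm->AllReduce(d_buf, n, xgboost::collective::DataType::kInt64,
                  xgboost::collective::Operation::kSum);
}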
84 changes: 62 additions & 22 deletions src/collective/nccl_device_communicator.cuh
@@ -72,20 +72,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
}

- void AllReduceSum(float *send_receive_buffer, size_t count) override {
- DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
- }
-
- void AllReduceSum(double *send_receive_buffer, size_t count) override {
- DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
- }
-
- void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
- DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
- }
+ void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
+ Operation op) override {
+ if (communicator_->GetWorldSize() == 1) {
+ return;
+ }

- void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
- DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
+ dh::safe_cuda(cudaSetDevice(device_ordinal_));
+ dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
+ GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
+ cuda_stream_));
+ allreduce_bytes_ += count * GetTypeSize(data_type);
+ allreduce_calls_ += 1;
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
@@ -162,17 +160,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return id;
}

- template <ncclDataType_t data_type, typename T>
- void DoAllReduceSum(T *send_receive_buffer, size_t count) {
- if (communicator_->GetWorldSize() == 1) {
- return;
+ static ncclDataType_t GetNcclDataType(DataType const &data_type) {
+ ncclDataType_t result;
+ switch (data_type) {
+ case DataType::kInt8:
+ result = ncclInt8;
+ break;
+ case DataType::kUInt8:
+ result = ncclUint8;
+ break;
+ case DataType::kInt32:
+ result = ncclInt32;
+ break;
+ case DataType::kUInt32:
+ result = ncclUint32;
+ break;
+ case DataType::kInt64:
+ result = ncclInt64;
+ break;
+ case DataType::kUInt64:
+ result = ncclUint64;
+ break;
+ case DataType::kFloat:
+ result = ncclFloat;
+ break;
+ case DataType::kDouble:
+ result = ncclDouble;
+ break;
+ default:
+ LOG(FATAL) << "Unknown data type.";
}
+ return result;
+ }

- dh::safe_cuda(cudaSetDevice(device_ordinal_));
- dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
- nccl_comm_, cuda_stream_));
- allreduce_bytes_ += count * sizeof(T);
- allreduce_calls_ += 1;
+ static ncclRedOp_t GetNcclRedOp(Operation const &op) {
+ ncclRedOp_t result;
+ switch (op) {
+ case Operation::kMax:
+ result = ncclMax;
+ break;
+ case Operation::kMin:
+ result = ncclMin;
+ break;
+ case Operation::kSum:
+ result = ncclSum;
+ break;
+ case Operation::kBitwiseAND:
+ case Operation::kBitwiseOR:
+ case Operation::kBitwiseXOR:
+ LOG(FATAL) << "Not implemented yet.";
+ default:
+ LOG(FATAL) << "Unknown reduce operation.";
+ }
+ return result;
}

int const device_ordinal_;
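
The practical effect of the new mapping helpers is that device all-reduce is no longer hard-wired to summation. A small sketch (illustrative, not from the commit) of a min-reduction that the old AllReduceSum-only interface could not express; bitwise operations would still abort with LOG(FATAL) in this NCCL backend, per GetNcclRedOp above:

// Hypothetical: reduce a per-worker device buffer of uint32 values to the global minimum.
// Assumes "communicator-inl.cuh" is included and the collective backend is initialized.
void MinAcrossWorkers(int device, std::uint32_t *d_values, std::size_t n) {
  xgboost::collective::AllReduce<xgboost::collective::Operation::kMin>(device, d_values, n);
}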
13 changes: 6 additions & 7 deletions src/common/quantile.cu
@@ -12,8 +12,7 @@
#include <memory>
#include <utility>

#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
@@ -510,7 +509,6 @@ void SketchContainer::AllReduce() {
}

timer_.Start(__func__);
- auto* communicator = collective::Communicator::GetDevice(device_);
// Reduce the overhead on syncing.
size_t global_sum_rows = num_rows_;
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
@@ -531,14 +529,15 @@
auto offset = rank * d_columns_ptr.size();
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
gathered_ptrs.begin() + offset);
- communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
+ collective::AllReduce<collective::Operation::kSum>(device_, gathered_ptrs.data().get(),
+ gathered_ptrs.size());

// Get the data from all workers.
std::vector<size_t> recv_lengths;
dh::caching_device_vector<char> recvbuf;
- communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
- &recv_lengths, &recvbuf);
- communicator->Synchronize();
+ collective::AllGatherV(device_, this->Current().data().get(),
+ dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
+ collective::Synchronize(device_);

// Segment the received data.
auto s_recvbuf = dh::ToSpan(recvbuf);
5 changes: 2 additions & 3 deletions src/metric/auc.cu
@@ -11,7 +11,7 @@
#include <tuple>
#include <utility>

#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/algorithm.cuh" // SegmentedArgSort
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
@@ -205,8 +205,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
- auto* communicator = collective::Communicator::GetDevice(device);
- communicator->AllReduceSum(results.data(), results.size());
+ collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
6 changes: 3 additions & 3 deletions src/tree/fit_stump.cu
@@ -11,7 +11,7 @@

#include <cstddef> // std::size_t

#include "../collective/device_communicator.cuh" // DeviceCommunicator
#include "../collective/communicator-inl.cuh"
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "fit_stump.h"
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
@@ -49,8 +49,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));

- collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id);
- communicator->AllReduceSum(reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
+ collective::AllReduce<collective::Operation::kSum>(
+ ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);

thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
[=] XGBOOST_DEVICE(std::size_t i) mutable {