diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu
index 8bd10382def5..8cdb7f2fd3bd 100644
--- a/src/collective/communicator.cu
+++ b/src/collective/communicator.cu
@@ -30,12 +30,12 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
     old_world_size = communicator_->GetWorldSize();
 #ifdef XGBOOST_USE_NCCL
     if (type_ != CommunicatorType::kFederated) {
-      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
+      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
     } else {
-      device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
+      device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
     }
 #else
-    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
+    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
 #endif
   }
   return device_communicator_.get();
diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh
index 06637c5b4768..f8135fb9473f 100644
--- a/src/collective/device_communicator_adapter.cuh
+++ b/src/collective/device_communicator_adapter.cuh
@@ -11,21 +11,18 @@ namespace collective {
 
 class DeviceCommunicatorAdapter : public DeviceCommunicator {
  public:
-  DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
-      : device_ordinal_{device_ordinal}, communicator_{communicator} {
+  explicit DeviceCommunicatorAdapter(int device_ordinal)
+      : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
     if (device_ordinal_ < 0) {
       LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
     }
-    if (communicator_ == nullptr) {
-      LOG(FATAL) << "Communicator cannot be null.";
-    }
   }
 
   ~DeviceCommunicatorAdapter() override = default;
 
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override {
-    if (communicator_->GetWorldSize() == 1) {
+    if (world_size_ == 1) {
       return;
     }
@@ -33,37 +30,34 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
     auto size = count * GetTypeSize(data_type);
     host_buffer_.reserve(size);
     dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
-    communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
+    Allreduce(host_buffer_.data(), count, data_type, op);
     dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
   }
 
   void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
                   dh::caching_device_vector<char> *receive_buffer) override {
-    if (communicator_->GetWorldSize() == 1) {
+    if (world_size_ == 1) {
       return;
     }
 
     dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    int const world_size = communicator_->GetWorldSize();
-    int const rank = communicator_->GetRank();
     segments->clear();
-    segments->resize(world_size, 0);
-    segments->at(rank) = length_bytes;
-    communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
-                             Operation::kMax);
+    segments->resize(world_size_, 0);
+    segments->at(rank_) = length_bytes;
+    Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
     auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
     receive_buffer->resize(total_bytes);
 
     host_buffer_.reserve(total_bytes);
     size_t offset = 0;
-    for (int32_t i = 0; i < world_size; ++i) {
+    for (int32_t i = 0; i < world_size_; ++i) {
       size_t as_bytes = segments->at(i);
-      if (i == rank) {
-        dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
+      if (i == rank_) {
+        dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
                                  cudaMemcpyDefault));
       }
-      communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
+      Broadcast(host_buffer_.data() + offset, as_bytes, i);
       offset += as_bytes;
     }
     dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
@@ -76,7 +70,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
 
  private:
   int const device_ordinal_;
-  Communicator *communicator_;
+  int const world_size_;
+  int const rank_;
   /// Host buffer used to call communicator functions.
   std::vector<char> host_buffer_{};
 };
diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu
index 631193db4d86..57419b947656 100644
--- a/src/collective/nccl_device_communicator.cu
+++ b/src/collective/nccl_device_communicator.cu
@@ -7,31 +7,24 @@
 namespace xgboost {
 namespace collective {
 
-NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
-    : device_ordinal_{device_ordinal}, communicator_{communicator} {
+NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
+    : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
   if (device_ordinal_ < 0) {
     LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
   }
-  if (communicator_ == nullptr) {
-    LOG(FATAL) << "Communicator cannot be null.";
-  }
-
-  int32_t const rank = communicator_->GetRank();
-  int32_t const world = communicator_->GetWorldSize();
-
-  if (world == 1) {
+  if (world_size_ == 1) {
     return;
   }
 
-  std::vector<uint64_t> uuids(world * kUuidLength, 0);
+  std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
   auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
-  auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
+  auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
   GetCudaUUID(s_this_uuid);
 
   // TODO(rongou): replace this with allgather.
-  communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
+  Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
 
-  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
+  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
   size_t j = 0;
   for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
     converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
@@ -41,18 +34,18 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator
   auto iter = std::unique(converted.begin(), converted.end());
   auto n_uniques = std::distance(converted.begin(), iter);
 
-  CHECK_EQ(n_uniques, world)
+  CHECK_EQ(n_uniques, world_size_)
       << "Multiple processes within communication group running on same CUDA "
      << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
 
   nccl_unique_id_ = GetUniqueId();
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
+  dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
   dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
 }
 
 NcclDeviceCommunicator::~NcclDeviceCommunicator() {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
   if (cuda_stream_) {
@@ -139,9 +132,8 @@ void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func,
 
 void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
                                               DataType data_type, Operation op) {
-  auto const world_size = communicator_->GetWorldSize();
   auto const size = count * GetTypeSize(data_type);
-  dh::caching_device_vector<char> buffer(size * world_size);
+  dh::caching_device_vector<char> buffer(size * world_size_);
   auto *device_buffer = buffer.data().get();
 
   // First gather data from all the workers.
@@ -152,15 +144,15 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
   auto *out_buffer = static_cast<char *>(send_receive_buffer);
   switch (op) {
     case Operation::kBitwiseAND:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
                           cuda_stream_);
       break;
     case Operation::kBitwiseOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
                           cuda_stream_);
       break;
     case Operation::kBitwiseXOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
                           cuda_stream_);
       break;
     default:
@@ -170,7 +162,7 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
 
 void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
                                        DataType data_type, Operation op) {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
@@ -189,24 +181,22 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
 void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
                                         std::vector<std::size_t> *segments,
                                         dh::caching_device_vector<char> *receive_buffer) {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
 
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  int const world_size = communicator_->GetWorldSize();
-  int const rank = communicator_->GetRank();
 
   segments->clear();
-  segments->resize(world_size, 0);
-  segments->at(rank) = length_bytes;
-  communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
+  segments->resize(world_size_, 0);
+  segments->at(rank_) = length_bytes;
+  Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
   auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
   receive_buffer->resize(total_bytes);
 
   size_t offset = 0;
   dh::safe_nccl(ncclGroupStart());
-  for (int32_t i = 0; i < world_size; ++i) {
+  for (int32_t i = 0; i < world_size_; ++i) {
     size_t as_bytes = segments->at(i);
     dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
                                 ncclChar, i, nccl_comm_, cuda_stream_));
@@ -216,7 +206,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
 }
 
 void NcclDeviceCommunicator::Synchronize() {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh
index e5f76119d914..925603d21252 100644
--- a/src/collective/nccl_device_communicator.cuh
+++ b/src/collective/nccl_device_communicator.cuh
@@ -12,7 +12,7 @@ namespace collective {
 
 class NcclDeviceCommunicator : public DeviceCommunicator {
  public:
-  NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
+  explicit NcclDeviceCommunicator(int device_ordinal);
   ~NcclDeviceCommunicator() override;
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override;
@@ -49,11 +49,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
   ncclUniqueId GetUniqueId() {
     static const int kRootRank = 0;
     ncclUniqueId id;
-    if (communicator_->GetRank() == kRootRank) {
+    if (rank_ == kRootRank) {
       dh::safe_nccl(ncclGetUniqueId(&id));
     }
-    communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
-                             static_cast<int>(kRootRank));
+    Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
     return id;
   }
 
@@ -61,7 +60,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
                         Operation op);
 
   int const device_ordinal_;
-  Communicator *communicator_;
+  int const world_size_;
+  int const rank_;
   ncclComm_t nccl_comm_{};
   cudaStream_t cuda_stream_{};
   ncclUniqueId nccl_unique_id_{};
diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu
index 6ac861a55877..81dd3d46db0d 100644
--- a/tests/cpp/collective/test_nccl_device_communicator.cu
+++ b/tests/cpp/collective/test_nccl_device_communicator.cu
@@ -16,12 +16,7 @@ namespace xgboost {
 namespace collective {
 
 TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
-  auto construct = []() { NcclDeviceCommunicator comm{-1, nullptr}; };
-  EXPECT_THROW(construct(), dmlc::Error);
-}
-
-TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) {
-  auto construct = []() { NcclDeviceCommunicator comm{0, nullptr}; };
+  auto construct = []() { NcclDeviceCommunicator comm{-1}; };
   EXPECT_THROW(construct(), dmlc::Error);
 }
diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h
index c4d303bb5437..20b4afc3026b 100644
--- a/tests/cpp/plugin/helpers.h
+++ b/tests/cpp/plugin/helpers.h
@@ -37,7 +37,14 @@ class ServerForTest {
   }
 
   ~ServerForTest() {
+    using namespace std::chrono_literals;
+    while (!server_) {
+      std::this_thread::sleep_for(100ms);
+    }
     server_->Shutdown();
+    while (!server_thread_) {
+      std::this_thread::sleep_for(100ms);
+    }
     server_thread_->join();
   }
 
@@ -56,7 +63,7 @@ class BaseFederatedTest : public ::testing::Test {
 
   void TearDown() override { server_.reset(nullptr); }
 
-  static int constexpr kWorldSize{3};
+  static int constexpr kWorldSize{2};
   std::unique_ptr<ServerForTest> server_;
 };
diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu
index 3fb793fa7160..134446f11350 100644
--- a/tests/cpp/plugin/test_federated_adapter.cu
+++ b/tests/cpp/plugin/test_federated_adapter.cu
@@ -9,6 +9,7 @@
 #include <thrust/sequence.h>
 
 #include "../../../plugin/federated/federated_communicator.h"
+#include "../../../src/collective/communicator-inl.cuh"
 #include "../../../src/collective/device_communicator_adapter.cuh"
 #include "./helpers.h"
@@ -17,67 +18,63 @@
 namespace xgboost::collective {
 
 class FederatedAdapterTest : public BaseFederatedTest {};
 
 TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
-  auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; };
+  auto construct = []() { DeviceCommunicatorAdapter adapter{-1}; };
   EXPECT_THROW(construct(), dmlc::Error);
 }
 
-TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
-  auto construct = []() { DeviceCommunicatorAdapter adapter{0, nullptr}; };
-  EXPECT_THROW(construct(), dmlc::Error);
+namespace {
+void VerifyAllReduceSum() {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  int count = 3;
+  thrust::device_vector<double> buffer(count, 0);
+  thrust::sequence(buffer.begin(), buffer.end());
+  collective::AllReduce<collective::Operation::kSum>(rank, buffer.data().get(), count);
+  thrust::host_vector<double> host_buffer = buffer;
+  EXPECT_EQ(host_buffer.size(), count);
+  for (auto i = 0; i < count; i++) {
+    EXPECT_EQ(host_buffer[i], i * world_size);
+  }
 }
+}  // anonymous namespace
 
-TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
-  std::vector<std::thread> threads;
-  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back([rank, server_address = server_->Address()] {
-      FederatedCommunicator comm{kWorldSize, rank, server_address};
-      // Assign device 0 to all workers, since we run gtest in a single-GPU machine
-      DeviceCommunicatorAdapter adapter{0, &comm};
-      int count = 3;
-      thrust::device_vector<double> buffer(count, 0);
-      thrust::sequence(buffer.begin(), buffer.end());
-      adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum);
-      thrust::host_vector<double> host_buffer = buffer;
-      EXPECT_EQ(host_buffer.size(), count);
-      for (auto i = 0; i < count; i++) {
-        EXPECT_EQ(host_buffer[i], i * kWorldSize);
-      }
-    });
-  }
-  for (auto& thread : threads) {
-    thread.join();
+TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
+  auto const n_gpus = common::AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUAllReduceSum test with # GPUs = " << n_gpus;
   }
+  RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
 }
 
-TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
-  std::vector<std::thread> threads;
-  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back([rank, server_address = server_->Address()] {
-      FederatedCommunicator comm{kWorldSize, rank, server_address};
-      // Assign device 0 to all workers, since we run gtest in a single-GPU machine
-      DeviceCommunicatorAdapter adapter{0, &comm};
+namespace {
+void VerifyAllGatherV() {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  int const count = rank + 2;
+  thrust::device_vector<char> buffer(count, 0);
+  thrust::sequence(buffer.begin(), buffer.end());
+  std::vector<std::size_t> segments(world_size);
+  dh::caching_device_vector<char> receive_buffer{};
 
-      int const count = rank + 2;
-      thrust::device_vector<char> buffer(count, 0);
-      thrust::sequence(buffer.begin(), buffer.end());
-      std::vector<std::size_t> segments(kWorldSize);
-      dh::caching_device_vector<char> receive_buffer{};
+  collective::AllGatherV(rank, buffer.data().get(), count, &segments, &receive_buffer);
 
-      adapter.AllGatherV(buffer.data().get(), count, &segments, &receive_buffer);
-
-      EXPECT_EQ(segments[0], 2);
-      EXPECT_EQ(segments[1], 3);
-      thrust::host_vector<char> host_buffer = receive_buffer;
-      EXPECT_EQ(host_buffer.size(), 9);
-      int expected[] = {0, 1, 0, 1, 2, 0, 1, 2, 3};
-      for (auto i = 0; i < 9; i++) {
-        EXPECT_EQ(host_buffer[i], expected[i]);
-      }
-    });
+  EXPECT_EQ(segments[0], 2);
+  EXPECT_EQ(segments[1], 3);
+  thrust::host_vector<char> host_buffer = receive_buffer;
+  EXPECT_EQ(host_buffer.size(), 5);
+  int expected[] = {0, 1, 0, 1, 2};
+  for (auto i = 0; i < 5; i++) {
+    EXPECT_EQ(host_buffer[i], expected[i]);
   }
-  for (auto& thread : threads) {
-    thread.join();
+}
+}  // anonymous namespace
+
+TEST_F(FederatedAdapterTest, MGPUAllGatherV) {
+  auto const n_gpus = common::AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUAllGatherV test with # GPUs = " << n_gpus;
   }
+  RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV);
 }
 
 }  // namespace xgboost::collective
diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc
index 62f33d5ee29a..8b0e1039adff 100644
--- a/tests/cpp/plugin/test_federated_communicator.cc
+++ b/tests/cpp/plugin/test_federated_communicator.cc
@@ -31,7 +31,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
 protected:
   static void CheckAllgather(FederatedCommunicator &comm, int rank) {
-    int buffer[kWorldSize] = {0, 0, 0};
+    int buffer[kWorldSize] = {0, 0};
     buffer[rank] = rank;
     comm.AllGather(buffer, sizeof(buffer));
     for (auto i = 0; i < kWorldSize; i++) {
@@ -42,7 +42,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
   static void CheckAllreduce(FederatedCommunicator &comm) {
     int buffer[] = {1, 2, 3, 4, 5};
     comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
-    int expected[] = {3, 6, 9, 12, 15};
+    int expected[] = {2, 4, 6, 8, 10};
     for (auto i = 0; i < 5; i++) {
       EXPECT_EQ(buffer[i], expected[i]);
     }
diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc
index c6efb84d5a77..6a8233a0fefe 100644
--- a/tests/cpp/plugin/test_federated_data.cc
+++ b/tests/cpp/plugin/test_federated_data.cc
@@ -30,7 +30,7 @@ void VerifyLoadUri() {
   std::string uri = path + "?format=csv";
   dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
 
-  ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
+  ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 1);
   ASSERT_EQ(dmat->Info().num_row_, kRows);
 
   for (auto const& page : dmat->GetBatches<SparsePage>()) {
diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc
index 4dd2f3c4031a..633d64df10f8 100644
--- a/tests/cpp/plugin/test_federated_server.cc
+++ b/tests/cpp/plugin/test_federated_server.cc
@@ -39,7 +39,7 @@ class FederatedServerTest : public BaseFederatedTest {
 protected:
   static void CheckAllgather(federated::FederatedClient& client, int rank) {
-    int data[kWorldSize] = {0, 0, 0};
+    int data[kWorldSize] = {0, 0};
     data[rank] = rank;
     std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
     auto reply = client.Allgather(send_buffer);
@@ -54,7 +54,7 @@ class FederatedServerTest : public BaseFederatedTest {
     std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
     auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM);
     auto const* result = reinterpret_cast<int const*>(reply.data());
-    int expected[] = {2, 4, 6, 8, 10};
-    int expected[] = {3, 6, 9, 12, 15};
+    int expected[] = {2, 4, 6, 8, 10};
     for (auto i = 0; i < 5; i++) {
       EXPECT_EQ(result[i], expected[i]);
     }
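Usage note (not part of the patch): after this refactor the device communicators no longer take a Communicator pointer; the explicit single-argument constructors capture the rank and world size from the global GetRank()/GetWorldSize() accessors, so the process-wide communicator must be initialized before a device communicator is constructed. A minimal sketch of the new construction path, assuming the global communicator is already set up; the wrapper function name and include path below are illustrative, not part of the patch:

    // Sketch: constructing a device communicator after this change.
    // Assumes the process-wide communicator has already been initialized,
    // so GetWorldSize()/GetRank() return meaningful values.
    #include "src/collective/nccl_device_communicator.cuh"  // repo-relative, illustrative

    void MakeDeviceCommunicator(int device_ordinal) {
      // Before this patch: NcclDeviceCommunicator comm{device_ordinal, Communicator::Get()};
      // Now the constructor is explicit and takes only the device ordinal;
      // world_size_ and rank_ are captured from GetWorldSize()/GetRank().
      xgboost::collective::NcclDeviceCommunicator comm{device_ordinal};
    }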