Skip to content

Commit

Permalink
[HybridParallel] fix port reuse when create multi group (#31876)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxicoding authored Apr 26, 2021
1 parent 8fec3c6 commit 41bfec8
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 8 deletions.
22 changes: 18 additions & 4 deletions paddle/fluid/imperative/nccl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace imperative {

void NCCLParallelContext::BcastNCCLId(
std::vector<ncclUniqueId> &nccl_ids, // NOLINT
int root) {
int root, int server_fd) {
if (strategy_.local_rank_ == root) {
std::vector<std::string> other_trainers;
for (auto &ep : strategy_.trainer_endpoints_) {
Expand All @@ -45,11 +45,14 @@ void NCCLParallelContext::BcastNCCLId(
}
platform::SendBroadCastCommID(other_trainers, &nccl_ids);
} else {
platform::RecvBroadCastCommID(strategy_.current_endpoint_, &nccl_ids);
platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
&nccl_ids);
}
}

void NCCLParallelContext::Init() {
int server_fd = -1;

std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(strategy_.nrings_);

Expand All @@ -58,8 +61,13 @@ void NCCLParallelContext::Init() {
for (size_t i = 0; i < nccl_ids.size(); ++i) {
platform::dynload::ncclGetUniqueId(&nccl_ids[i]);
}
} else {
// FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
// on rank0.
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastNCCLId(nccl_ids, 0);
BcastNCCLId(nccl_ids, 0, server_fd);

int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
Expand All @@ -80,14 +88,20 @@ void NCCLParallelContext::Init() {
}

void NCCLParallelContext::InitWithRingID(int ring_id) {
int server_fd = -1;
std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);

if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
platform::dynload::ncclGetUniqueId(&nccl_ids[0]);
} else {
// FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
// on rank0.
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastNCCLId(nccl_ids, 0);
BcastNCCLId(nccl_ids, 0, server_fd);

int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/imperative/nccl_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class NCCLParallelContext : public ParallelContext {

~NCCLParallelContext() override = default;

void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root); // NOLINT
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root, // NOLINT
int server_fd);

void Init() override;

Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/imperative/tests/nccl_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <thread> // NOLINT

#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"

#include "gtest/gtest.h"

Expand All @@ -36,9 +37,13 @@ imperative::ParallelStrategy GetStrategy(int local_rank) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids) {
auto strategy = GetStrategy(local_rank);
int server_fd = platform::CreateListenSocket(strategy.current_endpoint_);

platform::CUDAPlace gpu(local_rank);
imperative::NCCLParallelContext ctx(strategy, gpu);
ctx.BcastNCCLId(*nccl_ids, 0);
ctx.BcastNCCLId(*nccl_ids, 0, server_fd);

platform::CloseSocket(server_fd);
}

TEST(BcastNCCLId, Run) {
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class CGenNCCLIdOp : public framework::OperatorBase {
return Output("Out");
};

std::string endpoint = Attr<std::string>("endpoint");
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();

std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);

Expand All @@ -75,8 +78,6 @@ class CGenNCCLIdOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
} else {
std::string endpoint = Attr<std::string>("endpoint");
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
}

Expand Down

0 comments on commit 41bfec8

Please sign in to comment.