[hybrid] remove the use of the global ring in hybrid parallel #34525

Merged (10 commits), Aug 3, 2021
4 changes: 2 additions & 2 deletions paddle/fluid/imperative/bkcl_context.cc
@@ -92,7 +92,7 @@ void BKCLParallelContext::Init() {
<< " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
<< " ring id: " << ring_id;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateBKCLComm(
platform::BKCLCommContext::Instance().CreateComm(
&bkcl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, xpu_id,
ring_id);
}
@@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
<< " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
<< " ring id: " << ring_id;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateBKCLComm(
platform::BKCLCommContext::Instance().CreateComm(
&bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
}

4 changes: 2 additions & 2 deletions paddle/fluid/imperative/nccl_context.cc
@@ -75,7 +75,7 @@ void NCCLParallelContext::Init() {
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform::NCCLCommContext::Instance().CreateNCCLComm(
platform::NCCLCommContext::Instance().CreateComm(
&nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id,
ring_id);

@@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) {
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform::NCCLCommContext::Instance().CreateNCCLComm(
platform::NCCLCommContext::Instance().CreateComm(
&nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);

compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
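For orientation, a minimal sketch of what the renamed API looks like at a call site: the imperative parallel contexts above now build one communicator per ring through CreateComm instead of the backend-specific CreateNCCLComm/CreateBKCLComm. The helper InitAllRings is hypothetical, and the Get() lookup in the last comment is only assumed to exist on NCCLCommContext; neither is code from this PR.

// Sketch only, assuming a PADDLE_WITH_NCCL build.
#include <vector>
#include "paddle/fluid/platform/collective_helper.h"

namespace sketch {

void InitAllRings(std::vector<ncclUniqueId>* nccl_ids, int nranks,
                  int local_rank, int gpu_id) {
  auto& comm_ctx = paddle::platform::NCCLCommContext::Instance();
  // One communicator per ring; the index into nccl_ids doubles as the
  // ring id, mirroring NCCLParallelContext::Init() above.
  for (size_t ring_id = 0; ring_id < nccl_ids->size(); ++ring_id) {
    comm_ctx.CreateComm(&(*nccl_ids)[ring_id], nranks, local_rank, gpu_id,
                        static_cast<int>(ring_id));
  }
  // A collective op would later fetch the communicator for its ring, e.g.
  // comm_ctx.Get(/*ring_id=*/0, /*dev_id=*/gpu_id)->comm();
}

}  // namespace sketch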
84 changes: 40 additions & 44 deletions paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -24,15 +24,16 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif

namespace paddle {
namespace operators {
@@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase {

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
// TODO(wangxi): Put this in the unified header file
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
using UniqueId = ncclUniqueId;
using Place = platform::CUDAPlace;
using CommContext = platform::NCCLCommContext;
#elif defined(PADDLE_WITH_XPU_BKCL)
using UniqueId = BKCLUniqueId;
using Place = platform::XPUPlace;
using CommContext = platform::BKCLCommContext;
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with GPU or XPU."));
#endif

PADDLE_ENFORCE_EQ(is_gpu_place(place) || is_xpu_place(place), true,
platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu place only."));

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input can not be empty."));
if (is_gpu_place(place)) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();

int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with GPU."));
#endif
} else if (is_xpu_place(place)) {

UniqueId* comm_id = var->GetMutable<UniqueId>();

int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");

#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId* bkcl_id = var->GetMutable<BKCLUniqueId>();

int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
PADDLE_ENFORCE_EQ(
rid, 0,
platform::errors::OutOfRange(
"Ring id must equal 0 in multi Kunlun cards training, but got %d",
rid));
int device_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::BKCLCommContext::Instance().CreateBKCLComm(
bkcl_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with XPU."));
PADDLE_ENFORCE_EQ(
rid, 0,
platform::errors::OutOfRange(
"Ring id must equal 0 in multi Kunlun cards training, but got %d",
rid));
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu place only."));

int device_id = BOOST_GET_CONST(Place, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id,
rid);
#endif
}
};

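The core of the c_comm_init_op.cc hunk is compile-time backend selection: one set of using aliases (UniqueId, Place, CommContext) replaces the duplicated is_gpu_place/is_xpu_place branches, so the op body calls CreateComm exactly once. Below is a self-contained sketch of the same pattern with placeholder backends; FooContext/BarContext and USE_FOO are illustrative stand-ins, not Paddle identifiers.

#include <cstdio>

#define USE_FOO 1  // stand-in for a build flag such as PADDLE_WITH_NCCL

struct FooId {};
struct FooContext {
  static FooContext& Instance() {
    static FooContext ctx;
    return ctx;
  }
  void CreateComm(FooId* /*id*/, int nranks, int rank, int dev, int ring) {
    std::printf("foo comm: ring %d, rank %d/%d on device %d\n", ring, rank,
                nranks, dev);
  }
};

struct BarId {};
struct BarContext {
  static BarContext& Instance() {
    static BarContext ctx;
    return ctx;
  }
  void CreateComm(BarId* /*id*/, int, int, int, int) {}
};

// Pick the backend once with aliases, then write the body a single time.
void InitComm(int nranks, int rank, int dev, int ring) {
#if defined(USE_FOO)
  using UniqueId = FooId;
  using CommContext = FooContext;
#else
  using UniqueId = BarId;
  using CommContext = BarContext;
#endif
  UniqueId id;  // the real op reads this from the scope variable "X"
  CommContext::Instance().CreateComm(&id, nranks, rank, dev, ring);
}

int main() { InitComm(/*nranks=*/2, /*rank=*/0, /*dev=*/0, /*ring=*/3); }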
9 changes: 5 additions & 4 deletions paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
int ring_id = Attr<int>("ring_id");

std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
@@ -75,14 +75,13 @@
GenBKCLID(&bkcl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &bkcl_ids);
platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id);
} else {
std::string endpoint = Attr<std::string>("endpoint");
platform::RecvBroadCastCommID(endpoint, &bkcl_ids);
platform::RecvBroadCastCommID(endpoint, &bkcl_ids, ring_id);
}

CopyBKCLIDToVar(bkcl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};

@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};

9 changes: 5 additions & 4 deletions paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
int ring_id = Attr<int>("ring_id");

std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
@@ -79,13 +79,12 @@
GenHCCLID(&hccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &hccl_ids);
platform::SendBroadCastCommID(endpoint_list, &hccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids);
platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids, ring_id);
}

CopyHCCLIDToVar(hccl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};

@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};

9 changes: 5 additions & 4 deletions paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
@@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
int ring_id = Attr<int>("ring_id");

std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
@@ -76,13 +76,12 @@
GenNCCLID(&nccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
}

CopyNCCLIDToVar(nccl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};

@@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};

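All three gen-id ops above gain a ring_id attribute and forward it to SendBroadCastCommID / RecvBroadCastCommID, so each communication ring can exchange its own unique id instead of sharing the global ring. A rough sketch of driving c_gen_nccl_id with the new attribute from C++, roughly the way a unit test might; the scope variable name and endpoints are made up, and the exact OpRegistry::CreateOp overload is an assumption.

// Sketch: generate the NCCL unique id for ring 1 on rank 0.
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

void GenIdForRing1(paddle::framework::Scope* scope) {
  using paddle::framework::AttributeMap;
  using paddle::framework::OpRegistry;

  scope->Var("NCCLID_ring_1");  // will hold the generated ncclUniqueId

  AttributeMap attrs;
  attrs["rank"] = 0;
  attrs["endpoint"] = std::string("127.0.0.1:6170");
  attrs["other_endpoints"] = std::vector<std::string>{"127.0.0.1:6171"};
  attrs["ring_id"] = 1;  // the attribute added in this PR

  auto op = OpRegistry::CreateOp("c_gen_nccl_id", /*inputs=*/{},
                                 /*outputs=*/{{"Out", {"NCCLID_ring_1"}}},
                                 attrs);
  op->Run(*scope, paddle::platform::CPUPlace());
}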
8 changes: 4 additions & 4 deletions paddle/fluid/platform/collective_helper.cc
@@ -72,8 +72,8 @@ class NCCLCommImpl : public NCCLComm {
std::shared_ptr<platform::CudaEventObject> comm_event_;
};

NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
int rank, int dev_id, int ring_id) {
NCCLComm* NCCLCommContext::CreateComm(ncclUniqueId* nccl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(nccl_id,
platform::errors::InvalidArgument(
"The nccl unique id should not be null."));
@@ -225,8 +225,8 @@ class BKCLCommImpl : public BKCLComm {
std::unique_ptr<XPUDeviceContext> dev_ctx_;
};

BKCLComm* BKCLCommContext::CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks,
int rank, int dev_id, int ring_id) {
BKCLComm* BKCLCommContext::CreateComm(BKCLUniqueId* bkcl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(bkcl_id,
platform::errors::InvalidArgument(
"The bkcl unique id should not be null."));
8 changes: 4 additions & 4 deletions paddle/fluid/platform/collective_helper.h
@@ -72,8 +72,8 @@ class NCCLCommContext {
return comm_ctx;
}

NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank,
int dev_id, int ring_id = 0);
NCCLComm* CreateComm(ncclUniqueId* nccl_id, int nranks, int rank, int dev_id,
int ring_id = 0);

void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);

@@ -274,8 +274,8 @@ class BKCLCommContext {
return comm_ctx;
}

BKCLComm* CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, int rank,
int dev_id, int ring_id = 0);
BKCLComm* CreateComm(BKCLUniqueId* bkcl_id, int nranks, int rank, int dev_id,
int ring_id = 0);

void CreateAllBKCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
