upgrade comm in global gather
Difers committed Sep 14, 2023
1 parent 0ade0f8 commit 14aecea
Showing 3 changed files with 128 additions and 30 deletions.
114 changes: 86 additions & 28 deletions paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -18,7 +18,12 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);

namespace paddle {
namespace operators {
@@ -79,15 +84,42 @@ struct GlobalGatherFunctor<phi::GPUContext, T> {
"The ring_id (%d) for global gather op must be non-negative.",
ring_id));
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
gpuStream_t stream = nullptr;

platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
int nranks = 0;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You have chosen the new communication library by "
"setting the environment variable "
"FLAGS_dynamic_static_unified_comm to True, "
"but ring_id(%d) is "
"not found in comm_context_manager.",
ring_id));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr; collective ops "
"should have a ring_id attribute."));
stream = comm_ctx->GetStream();
nranks = comm_ctx->GetSize();
} else {
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
stream = comm->stream();
nranks = comm->nranks();
}
if (ctx.Attr<bool>("use_calc_stream")) {
// Use the calc stream from the ExecutionContext instead of the comm stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
int nranks = comm->nranks();

auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;

@@ -104,34 +136,60 @@ struct GlobalGatherFunctor<phi::GPUContext, T> {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto send_ptr = 0;
auto send_buf = x->data<T>();
auto recv_buf = out->mutable_data<T>(out_dims, place);
out->mutable_data<T>(out_dims, place);

for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(send_buf + send_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
send_ptr += cpu_global_count_data[idx];
if (comm_ctx) {
for (auto i = 0; i < n_expert; ++i) {
comm_ctx->GroupStart();
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
auto send_buf = distributed::GetPartialTensor(
*x, send_ptr * in_feat, cpu_global_count_data[idx] * in_feat);
comm_ctx->Send(
send_buf, cpu_global_count_data[idx] * in_feat, j, stream);
send_ptr += cpu_global_count_data[idx];
}
if (cpu_local_count_data[idx]) {
auto recv_buf = distributed::GetPartialTensor(
*out,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
comm_ctx->Recv(
&recv_buf, cpu_local_count_data[idx] * in_feat, j, stream);
}
}
if (cpu_local_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclRecv(recv_buf + expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
comm_ctx->GroupEnd();
}
} else {
auto send_buf = x->data<T>();
auto recv_buf = out->data<T>();
for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
send_buf + send_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
send_ptr += cpu_global_count_data[idx];
}
if (cpu_local_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
recv_buf + expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
#else
PADDLE_THROW(
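For orientation, both branches above follow the same bookkeeping: the rows of x and out are grouped into n_expert * nranks blocks, idx = i + j * n_expert addresses the block for expert i on peer rank j, send offsets come from accumulating cpu_global_count_data into send_ptr, and receive offsets come from expert_ptr, the prefix sums of cpu_local_count_data. The stand-alone Python sketch below (not part of the diff; plain lists and made-up counts, no Paddle involved) prints the element offsets and lengths each send/recv would use.

# Illustrative only: reproduces the offset bookkeeping of the kernel above
# with plain Python lists; the counts are made-up toy values.
n_expert, nranks = 2, 2
in_feat = 4  # row width, so element offsets below are row_offset * in_feat

# Per-(expert, peer) row counts, laid out as idx = i + j * n_expert.
local_count = [1, 2, 0, 3]   # rows received into `out` (recv side)
global_count = [2, 1, 1, 0]  # rows sent out of `x` (send side)

# expert_ptr: prefix sums of local_count, giving each recv block's offset.
expert_ptr = [0] * (n_expert * nranks)
for i in range(1, n_expert * nranks):
    expert_ptr[i] = expert_ptr[i - 1] + local_count[i - 1]

send_ptr = 0
for i in range(n_expert):        # one send/recv group per expert
    for j in range(nranks):      # peer rank j inside the group
        idx = i + j * n_expert   # block for expert i on rank j
        if global_count[idx]:
            print(f"send to rank {j}: offset {send_ptr * in_feat}, "
                  f"len {global_count[idx] * in_feat}")
            send_ptr += global_count[idx]
        if local_count[idx]:
            print(f"recv from rank {j}: offset {expert_ptr[idx] * in_feat}, "
                  f"len {local_count[idx] * in_feat}")

Each expert's inner loop over peers is wrapped in a GroupStart()/GroupEnd() pair (ncclGroupStart()/ncclGroupEnd() on the legacy path), so all sends and receives for one expert are issued as a single grouped NCCL operation.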
35 changes: 33 additions & 2 deletions test/collective/collective_global_gather.py
@@ -56,13 +56,40 @@ def get_model(self, main_prog, startup_program, rank, indata=None):

return [output]

def get_model_new_comm(self, main_prog, startup_program, rank, indata=None):
with base.program_guard(main_prog, startup_program):
seed = os.getpid()
np.random.seed(seed)
in_feat = 2
n_expert = 2
world_size = 2
tot_expert = n_expert * world_size
local_input_buf = paddle.static.data(
name="local_input_buf", shape=[-1, in_feat], dtype="float32"
)
local_expert_count = paddle.static.data(
name="local_expert_count", shape=[tot_expert], dtype="int64"
)
global_expert_count = paddle.static.data(
name="global_expert_count", shape=[tot_expert], dtype="int64"
)

output = moe_utils.global_gather(
local_input_buf, local_expert_count, global_expert_count
)

return [output]

def run_trainer(self, args):
train_prog = base.Program()
startup_prog = base.Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
paddle.distributed.init_parallel_env()
if args["dynamic_static_unified_comm"]:
paddle.distributed.collective._init_parallel_env(args["backend"])
else:
paddle.distributed.init_parallel_env()
nranks = 2
if args['backend'] == 'nccl':
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
@@ -109,7 +136,11 @@ def run_trainer(self, args):
)

if args['static_mode']:
result = self.get_model(train_prog, startup_prog, rank)
result = (
self.get_model_new_comm(train_prog, startup_prog, rank)
if args["dynamic_static_unified_comm"]
else self.get_model(train_prog, startup_prog, rank)
)
fetch_list = []
for elem in result:
fetch_list.append(elem.name)
9 changes: 9 additions & 0 deletions test/collective/test_collective_global_gather.py
@@ -38,6 +38,15 @@ def test_global_gather_nccl_dygraph_eager(self):
eager_mode=True,
)

def test_global_gather_nccl_new_comm(self):
paddle.enable_static()
self.check_with_place(
"collective_global_gather.py",
"global_gather",
"nccl",
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)


if __name__ == '__main__':
unittest.main()
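
The new test case reuses the existing static-mode NCCL test and only adds FLAGS_dynamic_static_unified_comm=1 through need_envs; the flag is read by the C++ kernel (PHI_DECLARE_bool above) and, on the Python side, decides which initialization call run_trainer makes. Below is a minimal sketch of that trainer-side decision, not part of the commit; the env-var parsing and the helper name are assumptions for illustration, while the two init calls are the ones used in the diff.

# Sketch only: mirrors the branch added to run_trainer above.
import os

import paddle.distributed as dist


def init_parallel_for_trainer(backend="nccl"):
    # The unit test passes need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
    # which is assumed here to reach the spawned trainer's environment.
    use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0").lower() in (
        "1",
        "true",
    )
    if use_new_comm:
        # New path: communicators are owned by phi's CommContextManager and
        # looked up by ring_id in the C++ kernel above.
        dist.collective._init_parallel_env(backend)
    else:
        # Legacy path: platform::NCCLCommContext keyed by ring_id.
        dist.init_parallel_env()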
