Refactor collective communication static check (#48646)
* refactor: classify static check

* refactor: rename to static_check & use forward decl

* refactor: switch to unary & binary funcs
HermitSun authored Dec 3, 2022
1 parent f9815bf commit 4552be4
Showing 6 changed files with 273 additions and 147 deletions.
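Note: of the six changed files, the two new ones (paddle/fluid/distributed/collective/static_check.h and static_check.cc) are not rendered on this page. As orientation for the call sites below, here is a speculative sketch of the declarations the new header presumably contains, reconstructed purely from the call sites in this diff and the helpers removed from NCCLTools. Whether CommStaticCheck is a struct, class, or namespace is an assumption, as are the comments.

// Hypothetical reconstruction of static_check.h, inferred from the call
// sites in this diff. Not the actual file contents.
#pragma once

namespace phi {
class DenseTensor;  // forward declaration, per the "use forward decl" commit note
}  // namespace phi

namespace paddle {
namespace distributed {

struct CommStaticCheck {
  // p2p check (replaces StaticCheckTensor): place and rank only.
  static void SingleTensor(const phi::DenseTensor& tensor,
                           int rank,
                           int world_size);

  // generic collective check (replaces StaticCheckTensors): place, rank,
  // dtype, and out_size * out_size_factor == in_size * in_size_factor.
  static void CheckShape(const phi::DenseTensor& out_tensor,
                         const phi::DenseTensor& in_tensor,
                         int dst_rank,
                         int cur_rank,
                         int world_size,
                         int out_size_factor,
                         int in_size_factor);

  // out_size == in_size (allreduce, broadcast, reduce).
  static void SameShape(const phi::DenseTensor& out_tensor,
                        const phi::DenseTensor& in_tensor,
                        int dst_rank,
                        int cur_rank,
                        int world_size);

  // out_size * world_size == in_size (reduce_scatter, scatter).
  static void ScatterLikeShape(const phi::DenseTensor& out_tensor,
                               const phi::DenseTensor& in_tensor,
                               int dst_rank,
                               int cur_rank,
                               int world_size);

  // out_size == in_size * world_size (all_gather).
  static void GatherLikeShape(const phi::DenseTensor& out_tensor,
                              const phi::DenseTensor& in_tensor,
                              int dst_rank,
                              int cur_rank,
                              int world_size);
};

}  // namespace distributed
}  // namespace paddle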
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/CMakeLists.txt
@@ -21,7 +21,7 @@ endif()
if(WITH_NCCL OR WITH_RCCL)
cc_library(
processgroup_nccl
-    SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc
+    SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc static_check.cc
DEPS processgroup
processgroup_stream
place
104 changes: 0 additions & 104 deletions paddle/fluid/distributed/collective/NCCLTools.cc
@@ -44,109 +44,5 @@ std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) {
return oss.str();
}

void StaticCheckTensor(const phi::DenseTensor& tensor,
int rank,
int world_size) {
// place check
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(tensor.place()),
true,
platform::errors::InvalidArgument("Tensor should be in GPU place."));
// rank check
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
platform::errors::InvalidArgument("Rank is out of the process group."));
}

// static check for collective
void StaticCheckTensors(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size,
int out_size_factor,
int in_size_factor) {
// place check
PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_tensor.place()),
true,
platform::errors::InvalidArgument(
"Output tensor should be in GPU place."));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(in_tensor.place()),
true,
platform::errors::InvalidArgument(
"Input tensor should be in GPU place."));
// rank check
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
platform::errors::InvalidArgument("Rank is out of the process group."));
// shape check
int64_t out_size = out_tensor.numel();
PADDLE_ENFORCE_GT(out_size,
0,
platform::errors::InvalidArgument(
"Size of output tensor should be greater than 0."));
int64_t in_size = in_tensor.numel();
PADDLE_ENFORCE_GT(in_size,
0,
platform::errors::InvalidArgument(
"Size of input tensor should be greater than 0."));
PADDLE_ENFORCE_EQ(
out_size * out_size_factor,
in_size * in_size_factor,
platform::errors::InvalidArgument(
"Input and output tensors should have matching sizes."));
// dtype check
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
platform::errors::InvalidArgument(
"Input and output tensors should have the same data type."));
}

void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size) {
StaticCheckTensors(out_tensor,
in_tensor,
rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ 1);
}

void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size) {
StaticCheckTensors(out_tensor,
in_tensor,
rank,
world_size,
/*out_size_factor*/ world_size,
/*in_size_factor*/ 1);
}

void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size) {
StaticCheckTensors(out_tensor,
in_tensor,
rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ world_size);
}

} // namespace distributed
} // namespace paddle
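The commit message's "switch to unary & binary funcs" suggests the monolithic checks deleted above were decomposed in static_check.cc into per-tensor (unary) and tensor-pair (binary) helpers. A minimal sketch of that decomposition follows; the helper names are hypothetical, while the enforcement calls are copied from the checks removed above.

// Sketch only: helper names are assumptions; the PADDLE_ENFORCE calls
// mirror the checks deleted from NCCLTools.cc.
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"

namespace paddle {
namespace distributed {

// unary check: one tensor, one property
static void CheckPlace(const phi::DenseTensor& tensor) {
  PADDLE_ENFORCE_EQ(
      platform::is_gpu_place(tensor.place()),
      true,
      platform::errors::InvalidArgument("Tensor should be in GPU place."));
}

// unary check: rank lies within the process group
static void CheckRank(int rank, int world_size) {
  PADDLE_ENFORCE_GE(rank,
                    0,
                    platform::errors::InvalidArgument(
                        "Rank should be greater than or equal to 0."));
  PADDLE_ENFORCE_LT(
      rank,
      world_size,
      platform::errors::InvalidArgument("Rank is out of the process group."));
}

// binary check: relate the sizes of the two tensors via the factors
static void CheckShape(const phi::DenseTensor& out_tensor,
                       const phi::DenseTensor& in_tensor,
                       int out_size_factor,
                       int in_size_factor) {
  PADDLE_ENFORCE_EQ(
      out_tensor.numel() * out_size_factor,
      in_tensor.numel() * in_size_factor,
      platform::errors::InvalidArgument(
          "Input and output tensors should have matching sizes."));
}

}  // namespace distributed
}  // namespace paddle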
27 changes: 0 additions & 27 deletions paddle/fluid/distributed/collective/NCCLTools.h
@@ -63,32 +63,5 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction);

std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);

// static check for p2p
void StaticCheckTensor(const phi::DenseTensor& tensor,
int rank,
int world_size);

// static check for collective
void StaticCheckTensors(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size,
int out_size_factor,
int in_size_factor);

void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);

void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);

void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);
} // namespace distributed
} // namespace paddle
55 changes: 40 additions & 15 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
Expand Up @@ -16,6 +16,7 @@

#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/static_check.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
@@ -138,8 +139,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
// numel > 0 indicates the tensor needs to be sliced
const phi::DenseTensor& in_tensor_maybe_partial =
numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
-  StaticCheckTensorsGatherLikeShape(
-      *out_tensor, in_tensor_maybe_partial, rank_, size_);
+  CommStaticCheck::GatherLikeShape(*out_tensor,
+                                   in_tensor_maybe_partial,
+                                   /*dst_rank*/ rank_,
+                                   /*cur_rank*/ rank_,
+                                   size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclAllGather(
@@ -162,7 +166,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ rank_,
+                             /*cur_rank*/ rank_,
+                             size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclAllReduce(
@@ -214,12 +222,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
// NOTE: Since `all_to_all` needs other processes' participation, it cannot
// simply be covered by static checks. Factors are set to 0 here to skip the
// shape check. Its shape check will be done by dynamic checks in debug mode.
-  StaticCheckTensors(*out_tensor,
-                     in_tensor,
-                     rank_,
-                     size_,
-                     /*out_size_factor*/ 0,
-                     /*in_size_factor*/ 0);
+  CommStaticCheck::CheckShape(*out_tensor,
+                              in_tensor,
+                              /*dst_rank*/ rank_,
+                              /*cur_rank*/ rank_,
+                              size_,
+                              /*out_size_factor*/ 0,
+                              /*in_size_factor*/ 0);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
int64_t in_row_size = in_tensor.numel() / in_dim[0],
@@ -287,7 +296,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ rank_,
+                             /*cur_rank*/ rank_,
+                             size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
int root = opts.source_rank + opts.source_root;
@@ -312,7 +325,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ opts.root_rank,
+                             /*cur_rank*/ rank_,
+                             size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduce(
@@ -337,7 +354,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
-  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::ScatterLikeShape(*out_tensor,
+                                    in_tensor,
+                                    /*dst_rank*/ rank_,
+                                    /*cur_rank*/ rank_,
+                                    size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduceScatter(
@@ -361,7 +382,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
-  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::ScatterLikeShape(*out_tensor,
+                                    in_tensor,
+                                    /*dst_rank*/ opts.root_rank,
+                                    /*cur_rank*/ rank_,
+                                    size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
int64_t numel = in_tensor.numel() / size_;
@@ -418,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
tensor = &partial_tensor;
}

-  StaticCheckTensor(*tensor, rank_, size_);
+  CommStaticCheck::SingleTensor(*tensor, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclRecv(
@@ -446,7 +471,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
const phi::DenseTensor& tensor_maybe_partial =
numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;

-  StaticCheckTensor(tensor_maybe_partial, rank_, size_);
+  CommStaticCheck::SingleTensor(tensor_maybe_partial, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclSend(
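One pattern worth noting in the call sites above: the old helpers took a single rank, while the new interface splits it into dst_rank and cur_rank. The symmetric collectives (AllGather, AllReduce, Broadcast, ReduceScatter) pass rank_ for both, whereas the rooted ops (Reduce, Scatter) pass opts.root_rank as dst_rank, presumably so the check can treat the root's tensor differently from the other ranks'. The two shapes of call, excerpted from the hunks above (not standalone code):

// Symmetric collective (e.g. AllReduce): every rank checks against itself.
CommStaticCheck::SameShape(*out_tensor, in_tensor,
                           /*dst_rank*/ rank_, /*cur_rank*/ rank_, size_);

// Rooted collective (e.g. Reduce): the destination rank comes from the options.
CommStaticCheck::SameShape(*out_tensor, in_tensor,
                           /*dst_rank*/ opts.root_rank, /*cur_rank*/ rank_, size_);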
