Refactor ccl send and recv (#8855)
* rename REGISTER_COLLECTIVE_COMMUNICATION_FACTORY to REGISTER_COLLECTIVE_COMMUNICATION

* refactor_ccl_allgather_and_reduce_scatter

* refactor ccl::Reduce

* remove useless code

* refactor ccl::Broadcast

* fix static check error

* resolve comment

* minor fix

* resolve comments

* fix macro lock error

* refine

* fix an idiot error

* fix reduce functor bug

* refactor_ccl_send_and_recv

* refine

* refine

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
clackhan and mergify[bot] authored Aug 21, 2022
1 parent e3a9b89 commit 5259a7c
Showing 25 changed files with 540 additions and 231 deletions.
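At the center of this refactor is a per-device Send/Recv interface under oneflow/user/kernels/collective_communication/include/. Those headers are not part of this view; judging from the Init/Launch overrides in the implementations below, the base classes plausibly look like the following sketch (names inferred from the diff, not confirmed by it):

```cpp
// Inferred sketch of include/send.h and include/recv.h; the real headers may
// differ. Signatures mirror the overrides in cpu_send.cpp and cpu_recv.cpp.
class Send {
 public:
  virtual ~Send() = default;
  virtual void Init(DataType datatype) = 0;
  virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt,
                      int64_t dst) const = 0;
};

class Recv {
 public:
  virtual ~Recv() = default;
  virtual void Init(DataType datatype) = 0;
  virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt,
                      int64_t src) const = 0;
};
```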
4 changes: 2 additions & 2 deletions oneflow/core/boxing/one_to_one_boxing.cpp
@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/user/kernels/communicate_util.h"

namespace oneflow {

@@ -31,8 +32,7 @@ Maybe<void> RawCheckNaiveOneToOne(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> ou
CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1);
CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());
CHECK_OR_RETURN(in->placement() != out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN(IsSendAndRecvRegistered(in->placement()->device_type())); // NOLINT
return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
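The hard-coded CPU/CUDA whitelist above gives way to a registry query: IsSendAndRecvRegistered (from the newly included communicate_util.h) reports whether Send and Recv implementations exist for a device type, so new backends work without touching the boxing code. A minimal sketch of such a capability check, assuming a simple set keyed by DeviceType; the real helper may instead consult the collective-communication factory:

```cpp
#include <set>

// Hypothetical registry of device types for which both Send and Recv
// implementations have been registered; illustrative only.
static std::set<DeviceType>& SendRecvDevices() {
  static std::set<DeviceType> devices;
  return devices;
}

bool IsSendAndRecvRegistered(DeviceType device_type) {
  return SendRecvDevices().count(device_type) > 0;
}
```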
6 changes: 2 additions & 4 deletions oneflow/core/boxing/slice_boxing_util.cpp
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/boxing/eager_boxing_logger.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/user/kernels/communicate_util.h"

namespace oneflow {

@@ -26,10 +27,7 @@ namespace private_details {
Maybe<one::Tensor> PreprocessInputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,
const std::string& log_prefix) {
const auto& tensor_placement = JUST(tensor->parallel_desc());
if (tensor_placement->device_type() == DeviceType::kCPU
|| tensor_placement->device_type() == DeviceType::kCUDA) {
return tensor;
}
if (IsSendAndRecvRegistered(tensor_placement->device_type())) { return tensor; }

const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
Symbol<ParallelDesc> new_placement = JUST(ReplaceDeviceType(tensor_placement, DeviceType::kCPU));
59 changes: 2 additions & 57 deletions oneflow/core/ccl/ccl.cpp
@@ -80,24 +80,7 @@ Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t
return Maybe<void>::Ok();
}

#ifdef WITH_CUDA
std::pair<ncclComm_t, int64_t> RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) {
std::set<std::pair<int64_t, int64_t>> device_set;
const int64_t& rank = GlobalProcessCtx::Rank();
const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0;
device_set.emplace(rank, GlobalProcessCtx::LocalRank());
device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id));
return {CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get())->GetCommForDevice(device_set),
peer_nccl_rank};
}
auto* GetNcclCommAndPeerNcclRank = DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal);
#endif

template<>
Maybe<void> Send<DeviceType::kCPU>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
Maybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst) {
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
NaiveAsyncTransportCtx transport_ctx(
transport_token,
@@ -115,28 +98,7 @@ Maybe<void> Send<DeviceType::kCPU>(const void* in, size_t elem_cnt, DataType dty
return Maybe<void>::Ok();
}

#ifdef WITH_CUDA
template<>
Maybe<void> Send<DeviceType::kCUDA>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
OF_NCCL_CHECK_OR_RETURN(ncclSend(in, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU send is only supported when nccl version >= 2.7";
#endif
}
#endif

template<>
Maybe<void> Recv<DeviceType::kCPU>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
Maybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src) {
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
NaiveAsyncTransportCtx transport_ctx(
transport_token,
@@ -154,22 +116,5 @@ Maybe<void> Recv<DeviceType::kCPU>(void* out, size_t elem_cnt, DataType dtype, i
return Maybe<void>::Ok();
}

#ifdef WITH_CUDA
template<>
Maybe<void> Recv<DeviceType::kCUDA>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
OF_NCCL_CHECK_OR_RETURN(ncclRecv(out, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU recv is only supported when nccl version >= 2.7";
#endif
}
#endif

} // namespace ccl
} // namespace oneflow
6 changes: 2 additions & 4 deletions oneflow/core/ccl/ccl.h
@@ -30,11 +30,9 @@ class TransportToken;
// collective communication library
namespace ccl {

template<DeviceType device_type>
Maybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, ep::Stream* stream);
Maybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst);

template<DeviceType device_type>
Maybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, ep::Stream* stream);
Maybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src);

Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,
Symbol<ParallelDesc> parallel_desc, const TransportToken& transport_token);
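The replacement API is byte-oriented and CPU-only: callers pass a raw buffer size rather than an element count plus dtype, and CUDA traffic now goes through the collective-communication classes added later in this commit. A usage sketch mirroring how comm_functor.cpp exchanges metadata below (the ranks and payload struct here are illustrative, not from the diff):

```cpp
// Illustrative POD payload; CpuSend/CpuRecv move raw bytes over the CPU
// transport, so any trivially copyable struct works.
struct TensorMeta {
  int64_t ndim;
  int64_t dims[8];
};

// Hypothetical exchange between process 0 (sender) and process 1 (receiver).
Maybe<void> ExchangeMeta(int64_t my_rank, TensorMeta* meta) {
  if (my_rank == 0) {
    JUST(ccl::CpuSend(meta, sizeof(*meta), /*dst=*/1));
  } else {
    JUST(ccl::CpuRecv(meta, sizeof(*meta), /*src=*/0));
  }
  return Maybe<void>::Ok();
}
```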
17 changes: 6 additions & 11 deletions oneflow/core/functional/impl/comm_functor.cpp
@@ -336,15 +336,13 @@ class SendFunctor {
JUST(attrs.SetAttr<int64_t>("dst_process_id", dst));
if (send_meta) {
std::shared_ptr<FlatShape> flat_shape = JUST(FlatShape::New(*x->shape()));
JUST(ccl::Send<DeviceType::kCPU>(flat_shape.get(), sizeof(*flat_shape), DataType::kChar, dst,
nullptr));
JUST(ccl::CpuSend(flat_shape.get(), sizeof(*flat_shape), dst));

DataType dtype = x->dtype()->data_type();
JUST(ccl::Send<DeviceType::kCPU>(&dtype, sizeof(dtype), DataType::kChar, dst, nullptr));
JUST(ccl::CpuSend(&dtype, sizeof(dtype), dst));

DeviceType device_type = JUST(Device::GetPlacement(*JUST(x->device())))->device_type();
JUST(ccl::Send<DeviceType::kCPU>(&device_type, sizeof(device_type), DataType::kChar, dst,
nullptr));
JUST(ccl::CpuSend(&device_type, sizeof(device_type), dst));
}
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_expr_, {x}, attrs));
return Maybe<void>::Ok();
@@ -373,16 +371,13 @@ } else if (!optional_shape.has_value() && !optional_dtype.has_value()
} else if (!optional_shape.has_value() && !optional_dtype.has_value()
&& !optional_device.has_value()) {
FlatShape flat_shape{};
JUST(ccl::Recv<DeviceType::kCPU>(&flat_shape, sizeof(flat_shape), DataType::kChar, src,
nullptr));
JUST(ccl::CpuRecv(&flat_shape, sizeof(flat_shape), src));
shape = *JUST(flat_shape.ToShape());

JUST(ccl::Recv<DeviceType::kCPU>(&data_type, sizeof(data_type), DataType::kChar, src,
nullptr));
JUST(ccl::CpuRecv(&data_type, sizeof(data_type), src));

DeviceType device_type = DeviceType::kInvalidDevice;
JUST(ccl::Recv<DeviceType::kCPU>(&device_type, sizeof(device_type), DataType::kChar, src,
nullptr));
JUST(ccl::CpuRecv(&device_type, sizeof(device_type), src));
device = JUST(Device::New(*JUST(DeviceTag4DeviceType(device_type))));
} else {
UNIMPLEMENTED_THEN_RETURN() << "All or none of shape, dtype and device should have value.";
1 change: 0 additions & 1 deletion oneflow/core/functional/impl/slice_boxing_functor.cpp
@@ -23,7 +23,6 @@ limitations under the License.
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/ccl/ccl.h"

namespace oneflow {
namespace one {
oneflow/user/kernels/collective_communication/cpu/cpu_broadcast.cpp
@@ -24,32 +24,32 @@ namespace oneflow {

namespace ccl {

// Use CpuBroadcastImpl to avoid name confilict
// Use CpuBroadcastImpl to avoid name conflict
class CpuBroadcastImpl final : public Broadcast {
public:
OF_DISALLOW_COPY_AND_MOVE(CpuBroadcastImpl);
CpuBroadcastImpl() : size_of_datatype_(0) {}
CpuBroadcastImpl() : size_of_dtype_(0) {}
~CpuBroadcastImpl() = default;

void Init(DataType datatype) override {
CHECK(IsPODDataType(datatype));
this->size_of_datatype_ = GetSizeOfDataType(datatype);
this->size_of_dtype_ = GetSizeOfDataType(datatype);
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root,
const std::shared_ptr<CommunicationContext>& communication_ctx) const override {
const auto& cpu_communication_ctx =
std::dynamic_pointer_cast<CpuCommunicationContext>(communication_ctx);
CHECK(cpu_communication_ctx);
size_t buffer_size = elem_cnt * size_of_datatype_;
size_t buffer_size = elem_cnt * size_of_dtype_;
const auto& transport_token =
CHECK_JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
CHECK_JUST(CpuBroadcast(in, out, buffer_size, root, cpu_communication_ctx->parallel_desc(),
transport_token));
}

private:
size_t size_of_datatype_;
size_t size_of_dtype_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Broadcast, CpuBroadcastImpl);
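REGISTER_COLLECTIVE_COMMUNICATION, renamed from REGISTER_COLLECTIVE_COMMUNICATION_FACTORY in this commit, binds an implementation class to a (device, operation) pair at static-initialization time. The macro's definition is not shown in this diff; below is a generic sketch of the registrar idiom such macros typically expand to, with every name hypothetical:

```cpp
#include <functional>
#include <map>
#include <memory>

// Hypothetical factory table mapping a device type to a constructor thunk,
// one table per operation interface (Send, Recv, Broadcast, ...).
template<typename Base>
std::map<int, std::function<std::unique_ptr<Base>()>>& FactoryTable() {
  static std::map<int, std::function<std::unique_ptr<Base>()>> table;
  return table;
}

// A static instance of this struct runs before main() and fills the table.
template<typename Base, typename Impl>
struct CollectiveCommRegistrar {
  explicit CollectiveCommRegistrar(int device_type) {
    FactoryTable<Base>()[device_type] = [] { return std::make_unique<Impl>(); };
  }
};

// Sketch of what the registration macro could expand to; the actual oneflow
// macro is defined elsewhere and may differ.
#define REGISTER_COLLECTIVE_COMMUNICATION_SKETCH(device, op, impl) \
  static CollectiveCommRegistrar<op, impl> g_##impl##_registrar{(device)};
```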
51 changes: 51 additions & 0 deletions oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp
@@ -0,0 +1,51 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/user/kernels/collective_communication/include/recv.h"

namespace oneflow {

namespace ccl {

// Use CpuRecvImpl to avoid name conflict
class CpuRecvImpl final : public Recv {
public:
OF_DISALLOW_COPY_AND_MOVE(CpuRecvImpl);
CpuRecvImpl() : size_of_dtype_(0) {}
~CpuRecvImpl() = default;

void Init(DataType datatype) override {
CHECK(IsPODDataType(datatype));
this->size_of_dtype_ = GetSizeOfDataType(datatype);
}

void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override {
size_t buffer_size = elem_cnt * size_of_dtype_;
CHECK_JUST(CpuRecv(out, buffer_size, src));
}

private:
size_t size_of_dtype_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Recv, CpuRecvImpl);

} // namespace ccl

} // namespace oneflow
51 changes: 51 additions & 0 deletions oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp
@@ -0,0 +1,51 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/user/kernels/collective_communication/include/send.h"

namespace oneflow {

namespace ccl {

// Use CpuSendImpl to avoid name conflict
class CpuSendImpl final : public Send {
public:
OF_DISALLOW_COPY_AND_MOVE(CpuSendImpl);
CpuSendImpl() : size_of_dtype_(0) {}
~CpuSendImpl() = default;

void Init(DataType datatype) override {
CHECK(IsPODDataType(datatype));
this->size_of_dtype_ = GetSizeOfDataType(datatype);
}

void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {
size_t buffer_size = elem_cnt * size_of_dtype_;
CHECK_JUST(CpuSend(in, buffer_size, dst));
}

private:
size_t size_of_dtype_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Send, CpuSendImpl);

} // namespace ccl

} // namespace oneflow
53 changes: 53 additions & 0 deletions oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp
@@ -0,0 +1,53 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/user/kernels/collective_communication/include/recv.h"
#include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h"
#include "oneflow/core/device/nccl_util.h"

namespace oneflow {

namespace ccl {

class CudaRecv final : public Recv {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaRecv);
CudaRecv() : nccl_datatype_() {}
~CudaRecv() = default;

void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }

void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override {
#if HAS_NCCL_SEND_RECV
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,
comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));
#else
UNIMPLEMENTED() << "GPU recv is only supported when nccl version >= 2.7";
#endif // HAS_NCCL_SEND_RECV
}

private:
ncclDataType_t nccl_datatype_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Recv, CudaRecv);

} // namespace ccl

} // namespace oneflow

#endif // WITH_CUDA
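GetNcclCommAndPeerNcclRank, previously defined in ccl.cpp (see the removed RawGetNcclCommAndPeerNcclRank above), now lives in cuda_send_recv_util.h. The removed code shows the trick: the point-to-point communicator holds exactly two ranks ordered by process id, so the peer's NCCL rank reduces to a comparison. A sketch of that computation, derived from the removed code rather than from the (unshown) new header:

```cpp
#include <cstdint>

// The two-process NCCL communicator orders its ranks by process id, so the
// peer occupies NCCL rank 1 exactly when its process id exceeds ours.
int64_t PeerNcclRank(int64_t my_process_id, int64_t peer_process_id) {
  return peer_process_id > my_process_id ? 1 : 0;
}
```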
