From 5259a7cd0fa48d7937fa131203ae45cc5c532068 Mon Sep 17 00:00:00 2001 From: binbinHan Date: Sun, 21 Aug 2022 23:44:48 +0800 Subject: [PATCH] Refactor ccl send and recv (#8855) * rename REGISTER_COLLECTIVE_COMMUNICATION_FACTORY to REGISTER_COLLECTIVE_COMMUNICATION * refactor_ccl_allgather_and_reduce_scatter * refactor ccl::Reduce * remove useless code * refactor ccl::Broadcast * fix static check error * reslove comment * monir fix * reslove comments * fix macro lock error * refine * fix an idiot error * fix reduce functor bug * refactor_ccl_send_and_recv * refine * refine Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- oneflow/core/boxing/one_to_one_boxing.cpp | 4 +- oneflow/core/boxing/slice_boxing_util.cpp | 6 +- oneflow/core/ccl/ccl.cpp | 59 +------------------ oneflow/core/ccl/ccl.h | 6 +- oneflow/core/functional/impl/comm_functor.cpp | 17 ++---- .../functional/impl/slice_boxing_functor.cpp | 1 - .../cpu/cpu_broadcast.cpp | 10 ++-- .../collective_communication/cpu/cpu_recv.cpp | 51 ++++++++++++++++ .../collective_communication/cpu/cpu_send.cpp | 51 ++++++++++++++++ .../cuda/cuda_recv.cpp | 53 +++++++++++++++++ .../cuda/cuda_send.cpp | 53 +++++++++++++++++ .../cuda/cuda_send_recv_util.cpp | 43 ++++++++++++++ .../cuda/cuda_send_recv_util.h | 34 +++++++++++ .../collective_communication/include/recv.h | 44 ++++++++++++++ .../collective_communication/include/send.h | 44 ++++++++++++++ oneflow/user/kernels/communicate_util.cpp | 47 +++++++-------- oneflow/user/kernels/communicate_util.h | 17 ++++-- oneflow/user/kernels/eager_b_to_s_kernel.cpp | 23 +++----- oneflow/user/kernels/eager_nccl_kernels.cpp | 15 ++--- oneflow/user/kernels/eager_p_to_b_kernel.cpp | 32 +++++----- oneflow/user/kernels/eager_p_to_s_kernel.cpp | 34 +++++------ oneflow/user/kernels/eager_s_to_b_kernel.cpp | 23 +++----- oneflow/user/kernels/eager_s_to_p_kernel.cpp | 31 +++++----- oneflow/user/kernels/eager_s_to_s_kernel.cpp | 25 +++----- oneflow/user/kernels/p2p_comm_kernel.cpp | 48 ++++++++++----- 25 files changed, 540 insertions(+), 231 deletions(-) create mode 100644 oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp create mode 100644 oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp create mode 100644 oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp create mode 100644 oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp create mode 100644 oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.cpp create mode 100644 oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h create mode 100644 oneflow/user/kernels/collective_communication/include/recv.h create mode 100644 oneflow/user/kernels/collective_communication/include/send.h diff --git a/oneflow/core/boxing/one_to_one_boxing.cpp b/oneflow/core/boxing/one_to_one_boxing.cpp index 31e7a98c1a0..1aaafeb26c1 100644 --- a/oneflow/core/boxing/one_to_one_boxing.cpp +++ b/oneflow/core/boxing/one_to_one_boxing.cpp @@ -19,6 +19,7 @@ limitations under the License. 
#include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" +#include "oneflow/user/kernels/communicate_util.h" namespace oneflow { @@ -31,8 +32,7 @@ Maybe RawCheckNaiveOneToOne(Symbol in, Symbol ou CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); CHECK_OR_RETURN(in->placement() != out->placement()); - CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU - || in->placement()->device_type() == DeviceType::kCUDA); + CHECK_OR_RETURN(IsSendAndRecvRegistered(in->placement()->device_type())); // NOLINT return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) diff --git a/oneflow/core/boxing/slice_boxing_util.cpp b/oneflow/core/boxing/slice_boxing_util.cpp index bea946177b8..81ec16ebb3a 100644 --- a/oneflow/core/boxing/slice_boxing_util.cpp +++ b/oneflow/core/boxing/slice_boxing_util.cpp @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/boxing/eager_boxing_logger.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" +#include "oneflow/user/kernels/communicate_util.h" namespace oneflow { @@ -26,10 +27,7 @@ namespace private_details { Maybe PreprocessInputTensor4SliceBoxing(const std::shared_ptr& tensor, const std::string& log_prefix) { const auto& tensor_placement = JUST(tensor->parallel_desc()); - if (tensor_placement->device_type() == DeviceType::kCPU - || tensor_placement->device_type() == DeviceType::kCUDA) { - return tensor; - } + if (IsSendAndRecvRegistered(tensor_placement->device_type())) { return tensor; } const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); Symbol new_placement = JUST(ReplaceDeviceType(tensor_placement, DeviceType::kCPU)); diff --git a/oneflow/core/ccl/ccl.cpp b/oneflow/core/ccl/ccl.cpp index 24b33526b6b..018e10cf94d 100644 --- a/oneflow/core/ccl/ccl.cpp +++ b/oneflow/core/ccl/ccl.cpp @@ -80,24 +80,7 @@ Maybe CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t return Maybe::Ok(); } -#ifdef WITH_CUDA -std::pair RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) { - std::set> device_set; - const int64_t& rank = GlobalProcessCtx::Rank(); - const int64_t peer_nccl_rank = (peer_process_id > rank) ? 
1 : 0; - device_set.emplace(rank, GlobalProcessCtx::LocalRank()); - device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id)); - return {CHECK_NOTNULL(Singleton::Get())->GetCommForDevice(device_set), - peer_nccl_rank}; -} -auto* GetNcclCommAndPeerNcclRank = DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal); -#endif - -template<> -Maybe Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, - ep::Stream* stream) { - CHECK_OR_RETURN(IsPODDataType(dtype)); - size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype); +Maybe CpuSend(const void* in, size_t buffer_size, int64_t dst) { TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); NaiveAsyncTransportCtx transport_ctx( transport_token, @@ -115,28 +98,7 @@ Maybe Send(const void* in, size_t elem_cnt, DataType dty return Maybe::Ok(); } -#ifdef WITH_CUDA -template<> -Maybe Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, - ep::Stream* stream) { -#if NCCL_VERSION_CODE >= 2700 - CHECK_OR_RETURN(IsPODDataType(dtype)); - const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst); - OF_NCCL_CHECK_OR_RETURN(ncclSend(in, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second, - comm_and_peer_rank.first, - stream->As()->cuda_stream())); - return Maybe::Ok(); -#else - UNIMPLEMENTED_THEN_RETURN() << "GPU send is only supported when nccl version >= 2.7" -#endif -} -#endif - -template<> -Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, - ep::Stream* stream) { - CHECK_OR_RETURN(IsPODDataType(dtype)); - size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype); +Maybe CpuRecv(void* out, size_t buffer_size, int64_t src) { TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); NaiveAsyncTransportCtx transport_ctx( transport_token, @@ -154,22 +116,5 @@ Maybe Recv(void* out, size_t elem_cnt, DataType dtype, i return Maybe::Ok(); } -#ifdef WITH_CUDA -template<> -Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, - ep::Stream* stream) { -#if NCCL_VERSION_CODE >= 2700 - CHECK_OR_RETURN(IsPODDataType(dtype)); - const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src); - OF_NCCL_CHECK_OR_RETURN(ncclRecv(out, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second, - comm_and_peer_rank.first, - stream->As()->cuda_stream())); - return Maybe::Ok(); -#else - UNIMPLEMENTED_THEN_RETURN() << "GPU recv is only supported when nccl version >= 2.7" -#endif -} -#endif - } // namespace ccl } // namespace oneflow diff --git a/oneflow/core/ccl/ccl.h b/oneflow/core/ccl/ccl.h index c15ec14916c..c3a0ceaa352 100644 --- a/oneflow/core/ccl/ccl.h +++ b/oneflow/core/ccl/ccl.h @@ -30,11 +30,9 @@ class TransportToken; // collective communication library namespace ccl { -template -Maybe Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, ep::Stream* stream); +Maybe CpuSend(const void* in, size_t buffer_size, int64_t dst); -template -Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, ep::Stream* stream); +Maybe CpuRecv(void* out, size_t buffer_size, int64_t src); Maybe CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root, Symbol parallel_desc, const TransportToken& transport_token); diff --git a/oneflow/core/functional/impl/comm_functor.cpp b/oneflow/core/functional/impl/comm_functor.cpp index 7641edf59d4..6e9280d0073 100644 --- a/oneflow/core/functional/impl/comm_functor.cpp +++ b/oneflow/core/functional/impl/comm_functor.cpp @@ -336,15 
+336,13 @@ class SendFunctor { JUST(attrs.SetAttr("dst_process_id", dst)); if (send_meta) { std::shared_ptr flat_shape = JUST(FlatShape::New(*x->shape())); - JUST(ccl::Send(flat_shape.get(), sizeof(*flat_shape), DataType::kChar, dst, - nullptr)); + JUST(ccl::CpuSend(flat_shape.get(), sizeof(*flat_shape), dst)); DataType dtype = x->dtype()->data_type(); - JUST(ccl::Send(&dtype, sizeof(dtype), DataType::kChar, dst, nullptr)); + JUST(ccl::CpuSend(&dtype, sizeof(dtype), dst)); DeviceType device_type = JUST(Device::GetPlacement(*JUST(x->device())))->device_type(); - JUST(ccl::Send(&device_type, sizeof(device_type), DataType::kChar, dst, - nullptr)); + JUST(ccl::CpuSend(&device_type, sizeof(device_type), dst)); } JUST(OpInterpUtil::Dispatch(*op_expr_, {x}, attrs)); return Maybe::Ok(); @@ -373,16 +371,13 @@ class RecvFunctor { } else if (!optional_shape.has_value() && !optional_dtype.has_value() && !optional_device.has_value()) { FlatShape flat_shape{}; - JUST(ccl::Recv(&flat_shape, sizeof(flat_shape), DataType::kChar, src, - nullptr)); + JUST(ccl::CpuRecv(&flat_shape, sizeof(flat_shape), src)); shape = *JUST(flat_shape.ToShape()); - JUST(ccl::Recv(&data_type, sizeof(data_type), DataType::kChar, src, - nullptr)); + JUST(ccl::CpuRecv(&data_type, sizeof(data_type), src)); DeviceType device_type = DeviceType::kInvalidDevice; - JUST(ccl::Recv(&device_type, sizeof(device_type), DataType::kChar, src, - nullptr)); + JUST(ccl::CpuRecv(&device_type, sizeof(device_type), src)); device = JUST(Device::New(*JUST(DeviceTag4DeviceType(device_type)))); } else { UNIMPLEMENTED_THEN_RETURN() << "All or none of shape, dtype and device should have value."; diff --git a/oneflow/core/functional/impl/slice_boxing_functor.cpp b/oneflow/core/functional/impl/slice_boxing_functor.cpp index 08e24a5c9de..56861f0f010 100644 --- a/oneflow/core/functional/impl/slice_boxing_functor.cpp +++ b/oneflow/core/functional/impl/slice_boxing_functor.cpp @@ -23,7 +23,6 @@ limitations under the License. 
#include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/common.h" -#include "oneflow/core/ccl/ccl.h" namespace oneflow { namespace one { diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_broadcast.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_broadcast.cpp index 95194a98da5..873ea779cbb 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_broadcast.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_broadcast.cpp @@ -24,16 +24,16 @@ namespace oneflow { namespace ccl { -// Use CpuBroadcastImpl to avoid name confilict +// Use CpuBroadcastImpl to avoid name conflict class CpuBroadcastImpl final : public Broadcast { public: OF_DISALLOW_COPY_AND_MOVE(CpuBroadcastImpl); - CpuBroadcastImpl() : size_of_datatype_(0) {} + CpuBroadcastImpl() : size_of_dtype_(0) {} ~CpuBroadcastImpl() = default; void Init(DataType datatype) override { CHECK(IsPODDataType(datatype)); - this->size_of_datatype_ = GetSizeOfDataType(datatype); + this->size_of_dtype_ = GetSizeOfDataType(datatype); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root, @@ -41,7 +41,7 @@ class CpuBroadcastImpl final : public Broadcast { const auto& cpu_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cpu_communication_ctx); - size_t buffer_size = elem_cnt * size_of_datatype_; + size_t buffer_size = elem_cnt * size_of_dtype_; const auto& transport_token = CHECK_JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); CHECK_JUST(CpuBroadcast(in, out, buffer_size, root, cpu_communication_ctx->parallel_desc(), @@ -49,7 +49,7 @@ class CpuBroadcastImpl final : public Broadcast { } private: - size_t size_of_datatype_; + size_t size_of_dtype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Broadcast, CpuBroadcastImpl); diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp new file mode 100644 index 00000000000..c7dc335b404 --- /dev/null +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp @@ -0,0 +1,51 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/ccl/ccl.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/framework/transport_util.h" +#include "oneflow/user/kernels/collective_communication/include/recv.h" + +namespace oneflow { + +namespace ccl { + +// Use CpuRecvImpl to avoid name conflict +class CpuRecvImpl final : public Recv { + public: + OF_DISALLOW_COPY_AND_MOVE(CpuRecvImpl); + CpuRecvImpl() : size_of_dtype_(0) {} + ~CpuRecvImpl() = default; + + void Init(DataType datatype) override { + CHECK(IsPODDataType(datatype)); + this->size_of_dtype_ = GetSizeOfDataType(datatype); + } + + void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override { + size_t buffer_size = elem_cnt * size_of_dtype_; + CHECK_JUST(CpuRecv(out, buffer_size, src)); + } + + private: + size_t size_of_dtype_; +}; + +REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Recv, CpuRecvImpl); + +} // namespace ccl + +} // namespace oneflow diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp new file mode 100644 index 00000000000..5a93b9255c5 --- /dev/null +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp @@ -0,0 +1,51 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/ccl/ccl.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/framework/transport_util.h" +#include "oneflow/user/kernels/collective_communication/include/send.h" + +namespace oneflow { + +namespace ccl { + +// Use CpuSendImpl to avoid name conflict +class CpuSendImpl final : public Send { + public: + OF_DISALLOW_COPY_AND_MOVE(CpuSendImpl); + CpuSendImpl() : size_of_dtype_(0) {} + ~CpuSendImpl() = default; + + void Init(DataType datatype) override { + CHECK(IsPODDataType(datatype)); + this->size_of_dtype_ = GetSizeOfDataType(datatype); + } + + void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override { + size_t buffer_size = elem_cnt * size_of_dtype_; + CHECK_JUST(CpuSend(in, buffer_size, dst)); + } + + private: + size_t size_of_dtype_; +}; + +REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Send, CpuSendImpl); + +} // namespace ccl + +} // namespace oneflow diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp new file mode 100644 index 00000000000..cc4bcfafe3f --- /dev/null +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp @@ -0,0 +1,53 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifdef WITH_CUDA
+#include "oneflow/user/kernels/collective_communication/include/recv.h"
+#include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h"
+#include "oneflow/core/device/nccl_util.h"
+
+namespace oneflow {
+
+namespace ccl {
+
+class CudaRecv final : public Recv {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(CudaRecv);
+  CudaRecv() : nccl_datatype_() {}
+  ~CudaRecv() = default;
+
+  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }
+
+  void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override {
+#if HAS_NCCL_SEND_RECV
+    const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
+    OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,
+                           comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));
+#else
+    UNIMPLEMENTED() << "GPU recv is only supported when nccl version >= 2.7";
+#endif  // HAS_NCCL_SEND_RECV
+  }
+
+ private:
+  ncclDataType_t nccl_datatype_;
+};
+
+REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Recv, CudaRecv);
+
+}  // namespace ccl
+
+}  // namespace oneflow
+
+#endif  // WITH_CUDA
diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp
new file mode 100644
index 00000000000..da7ac181252
--- /dev/null
+++ b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp
@@ -0,0 +1,53 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifdef WITH_CUDA
+#include "oneflow/user/kernels/collective_communication/include/send.h"
+#include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h"
+#include "oneflow/core/device/nccl_util.h"
+
+namespace oneflow {
+
+namespace ccl {
+
+class CudaSend final : public Send {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(CudaSend);
+  CudaSend() : nccl_datatype_() {}
+  ~CudaSend() = default;
+
+  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }
+
+  void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {
+#if HAS_NCCL_SEND_RECV
+    const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
+    OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,
+                           comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));
+#else
+    UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7";
+#endif  // HAS_NCCL_SEND_RECV
+  }
+
+ private:
+  ncclDataType_t nccl_datatype_;
+};
+
+REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Send, CudaSend);
+
+}  // namespace ccl
+
+}  // namespace oneflow
+
+#endif  // WITH_CUDA
diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.cpp
new file mode 100644
index 00000000000..49fb76478c4
--- /dev/null
+++ b/oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.cpp
@@ -0,0 +1,43 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h"
+#include "oneflow/core/rpc/include/global_process_ctx.h"
+#include "oneflow/core/common/decorator.h"
+#ifdef WITH_CUDA
+#include "oneflow/core/job/eager_nccl_comm_manager.h"
+
+namespace oneflow {
+
+namespace ccl {
+
+std::pair<ncclComm_t, int64_t> RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) {
+  std::set<std::pair<int64_t, int64_t>> device_set;
+  const int64_t& rank = GlobalProcessCtx::Rank();
+  const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0;
+  device_set.emplace(rank, GlobalProcessCtx::LocalRank());
+  device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id));
+  return {CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get())->GetCommForDevice(device_set),
+          peer_nccl_rank};
+}
+
+decltype(GetNcclCommAndPeerNcclRank) GetNcclCommAndPeerNcclRank =
+    DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal);
+
+}  // namespace ccl
+
+}  // namespace oneflow
+
+#endif  // WITH_CUDA
diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h b/oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h
new file mode 100644
index 00000000000..438a39390e7
--- /dev/null
+++ b/oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h
@@ -0,0 +1,34 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_
+#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_
+
+#ifdef WITH_CUDA
+#include "oneflow/core/device/nccl_util.h"
+
+namespace oneflow {
+
+namespace ccl {
+
+extern std::pair<ncclComm_t, int64_t> (*GetNcclCommAndPeerNcclRank)(int64_t peer_process_id);
+
+}  // namespace ccl
+
+}  // namespace oneflow
+
+#endif  // WITH_CUDA
+
+#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_
diff --git a/oneflow/user/kernels/collective_communication/include/recv.h b/oneflow/user/kernels/collective_communication/include/recv.h
new file mode 100644
index 00000000000..59c1aef849f
--- /dev/null
+++ b/oneflow/user/kernels/collective_communication/include/recv.h
@@ -0,0 +1,44 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECV_H_
+#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECV_H_
+
+#include "oneflow/user/kernels/collective_communication/include/collective_communication.h"
+
+namespace oneflow {
+
+namespace ccl {
+
+class Recv : public CollectiveCommunication {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(Recv);
+  Recv() = default;
+  ~Recv() override = default;
+
+  virtual void Init(DataType dtype) = 0;
+
+  virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0;
+};
+
+inline bool IsRecvRegistered(DeviceType device_type) {
+  return IsClassRegistered<DeviceType, Recv>(device_type);
+}
+
+}  // namespace ccl
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECV_H_
diff --git a/oneflow/user/kernels/collective_communication/include/send.h b/oneflow/user/kernels/collective_communication/include/send.h
new file mode 100644
index 00000000000..6658c7de292
--- /dev/null
+++ b/oneflow/user/kernels/collective_communication/include/send.h
@@ -0,0 +1,44 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_
+#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_
+
+#include "oneflow/user/kernels/collective_communication/include/collective_communication.h"
+
+namespace oneflow {
+
+namespace ccl {
+
+class Send : public CollectiveCommunication {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(Send);
+  Send() = default;
+  ~Send() override = default;
+
+  virtual void Init(DataType dtype) = 0;
+
+  virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const = 0;
+};
+
+inline bool IsSendRegistered(DeviceType device_type) {
+  return IsClassRegistered<DeviceType, Send>(device_type);
+}
+
+}  // namespace ccl
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_
diff --git a/oneflow/user/kernels/communicate_util.cpp b/oneflow/user/kernels/communicate_util.cpp
index 082972e385b..d9795cd8587 100644
--- a/oneflow/user/kernels/communicate_util.cpp
+++ b/oneflow/user/kernels/communicate_util.cpp
@@ -14,13 +14,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/user/kernels/communicate_util.h"
-#include "oneflow/core/device/nccl_util.h"
-#include "oneflow/core/common/container_util.h"
-#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/ep/include/primitive/memcpy.h"
 #include "oneflow/core/kernel/new_kernel_util.h"
-#include "oneflow/core/ccl/ccl.h"
-#include "oneflow/core/job/parallel_desc.h"
 #include "oneflow/core/control/global_process_ctx.h"
+#include "oneflow/user/kernels/collective_communication/include/send.h"
+#include "oneflow/user/kernels/collective_communication/include/recv.h"
 
 namespace oneflow {
 
@@ -33,44 +31,43 @@ const void** ThreadLocalSrcDataPtr() {
 
 }  // namespace
 
-template<DeviceType device_type>
-Maybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, ep::Stream* stream) {
+bool IsSendAndRecvRegistered(DeviceType device_type) {
+  return ccl::IsSendRegistered(device_type) && ccl::IsRecvRegistered(device_type);
+}
+
+Maybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
+                 DeviceType device_type, ep::Stream* stream) {
   if (GlobalProcessCtx::Rank() == dst) {
     auto** src_data_ptr = ThreadLocalSrcDataPtr();
     CHECK_OR_RETURN(*src_data_ptr == nullptr);
     *src_data_ptr = in;
   } else {
-    JUST(ccl::Send<device_type>(in, elem_cnt, dtype, dst, stream));
+    std::unique_ptr<ccl::Send> send =
+        ccl::NewCollectiveCommunication<ccl::Send>(device_type, dtype);
+    send->Launch(stream, in, elem_cnt, dst);
   }
   return Maybe<void>::Ok();
 }
 
-template<DeviceType device_type>
-Maybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, ep::Stream* stream) {
+Maybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, DeviceType device_type,
+                 ep::Stream* stream) {
   if (GlobalProcessCtx::Rank() == src) {
     size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
     auto** src_data_ptr = ThreadLocalSrcDataPtr();
     const void* in = *src_data_ptr;
    CHECK_OR_RETURN(*src_data_ptr != nullptr);
-    Memcpy<device_type>(stream, out, in, buffer_size);
+    std::unique_ptr<ep::primitive::Memcpy> memcpy_primitive =
+        ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(device_type,
+                                                                  ep::primitive::MemcpyKind::kDtoD);
+    CHECK(memcpy_primitive) << "Can not create Memcpy primitive for device type " << device_type;
+    memcpy_primitive->Launch(stream, out, in, buffer_size);
     *src_data_ptr = nullptr;
   } else {
-    JUST(ccl::Recv<device_type>(out, elem_cnt, dtype, src, stream));
+    std::unique_ptr<ccl::Recv> recv =
+        ccl::NewCollectiveCommunication<ccl::Recv>(device_type, dtype);
+    recv->Launch(stream, out, elem_cnt, src);
  }
   return Maybe<void>::Ok();
 }
 
-template Maybe<void>
Send(const void* in, size_t elem_cnt, DataType dtype, - int64_t dst, ep::Stream* stream); - -template Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, - ep::Stream* stream); - -#if defined(WITH_CUDA) && HAS_NCCL_SEND_RECV -template Maybe Send(const void* in, size_t elem_cnt, DataType dtype, - int64_t dst, ep::Stream* stream); - -template Maybe Recv(void* out, size_t elem_cnt, DataType dtype, - int64_t src, ep::Stream* stream); -#endif } // namespace oneflow diff --git a/oneflow/user/kernels/communicate_util.h b/oneflow/user/kernels/communicate_util.h index 75f8c33f731..3db423f484c 100644 --- a/oneflow/user/kernels/communicate_util.h +++ b/oneflow/user/kernels/communicate_util.h @@ -18,18 +18,27 @@ limitations under the License. #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/stream.h" +#include "oneflow/core/framework/user_op_kernel_registry.h" namespace oneflow { +bool IsSendAndRecvRegistered(DeviceType device_type); + +ALWAYS_INLINE inline auto HobIsSendAndRecvRegistered() { + return hob::make_custom("HobIsSendAndRecvRegistered", [](const user_op::KernelRegContext& ctx) { + return IsSendAndRecvRegistered(ctx.device_type()); + }); +} + // Send data from in to rank dst, if cur rank equal dst, memcopy will happen. // Rank dst needs to call Recv with the same datatype and the same count from this rank. -template -Maybe Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, ep::Stream* stream); +Maybe Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, + DeviceType device_type, ep::Stream* stream); // Receive data from rank src into out, if cur rank equal src, memcopy will happen. // Rank src needs to call Send with the same datatype and the same count to this rank. -template -Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, ep::Stream* stream); +Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, DeviceType device_type, + ep::Stream* stream); } // namespace oneflow diff --git a/oneflow/user/kernels/eager_b_to_s_kernel.cpp b/oneflow/user/kernels/eager_b_to_s_kernel.cpp index 17259d7323f..907dc50a313 100644 --- a/oneflow/user/kernels/eager_b_to_s_kernel.cpp +++ b/oneflow/user/kernels/eager_b_to_s_kernel.cpp @@ -153,7 +153,6 @@ size_t InferEagerBToSKernelTmpBufferSize(user_op::InferContext* ctx) { } // namespace -template class EagerBToSKernel final : public user_op::OpKernel { public: EagerBToSKernel() = default; @@ -185,6 +184,8 @@ class EagerBToSKernel final : public user_op::OpKernel { CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); + DeviceType device_type = ctx->device_type(); + for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; @@ -202,8 +203,8 @@ class EagerBToSKernel final : public user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); - CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, - in->data_type(), dst, ctx->stream())); + CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), + dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = @@ -211,7 +212,7 @@ class EagerBToSKernel final : public 
user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( - Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, ctx->stream())); + Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, reinterpret_cast(tmp_buffer_ptr)); } @@ -220,15 +221,9 @@ class EagerBToSKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_EAGER_B_TO_S_KERNEL(device) \ - REGISTER_USER_KERNEL("eager_b_to_s") \ - .SetCreateFn>() \ - .SetIsMatchedHob(user_op::HobDeviceType() == device) \ - .SetInferTmpSizeFn(InferEagerBToSKernelTmpBufferSize); - -REGISTER_EAGER_B_TO_S_KERNEL(DeviceType::kCPU) -#if defined(WITH_CUDA) && HAS_NCCL_SEND_RECV -REGISTER_EAGER_B_TO_S_KERNEL(DeviceType::kCUDA) -#endif +REGISTER_USER_KERNEL("eager_b_to_s") + .SetCreateFn() + .SetIsMatchedHob(HobIsSendAndRecvRegistered()) + .SetInferTmpSizeFn(InferEagerBToSKernelTmpBufferSize); } // namespace oneflow diff --git a/oneflow/user/kernels/eager_nccl_kernels.cpp b/oneflow/user/kernels/eager_nccl_kernels.cpp index 56fba121550..4272099f8e8 100644 --- a/oneflow/user/kernels/eager_nccl_kernels.cpp +++ b/oneflow/user/kernels/eager_nccl_kernels.cpp @@ -17,7 +17,6 @@ limitations under the License. #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" -#include "oneflow/core/ccl/ccl.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/kernel/new_kernel_util.h" @@ -178,10 +177,9 @@ class EagerCclS2SKernel final : public user_op::OpKernel { int64_t parallel_id = CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(dst, device_id)); - CHECK_JUST(Send( - reinterpret_cast(reinterpret_cast(pack_to_ptr) - + parallel_id * chunk_size), - elem_per_chunk, in->data_type(), dst, ctx->stream())); + CHECK_JUST(Send(reinterpret_cast(reinterpret_cast(pack_to_ptr) + + parallel_id * chunk_size), + elem_per_chunk, in->data_type(), dst, DeviceType::kCPU, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { Symbol parallel_desc = kernel_cache->parallel_desc(); @@ -189,10 +187,9 @@ class EagerCclS2SKernel final : public user_op::OpKernel { int64_t parallel_id = CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(src, device_id)); - CHECK_JUST(Recv( - reinterpret_cast(reinterpret_cast(unpack_from_ptr) - + parallel_id * chunk_size), - elem_per_chunk, out->data_type(), src, ctx->stream())); + CHECK_JUST(Recv(reinterpret_cast(reinterpret_cast(unpack_from_ptr) + + parallel_id * chunk_size), + elem_per_chunk, out->data_type(), src, DeviceType::kCPU, ctx->stream())); } } } diff --git a/oneflow/user/kernels/eager_p_to_b_kernel.cpp b/oneflow/user/kernels/eager_p_to_b_kernel.cpp index 0a0b8ee0ede..da6ab32b6fd 100644 --- a/oneflow/user/kernels/eager_p_to_b_kernel.cpp +++ b/oneflow/user/kernels/eager_p_to_b_kernel.cpp @@ -22,6 +22,7 @@ limitations under the License. 
#include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/ep/include/primitive/add.h" +#include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { @@ -65,7 +66,6 @@ size_t InferEagerPToBKernelTmpBufferSize(user_op::InferContext* ctx) { } // namespace -template class EagerPToBKernel final : public user_op::OpKernel { public: EagerPToBKernel() = default; @@ -91,8 +91,14 @@ class EagerPToBKernel final : public user_op::OpKernel { const int64_t total_elem_cnt = ctx->Attr("shape").elem_cnt(); const auto& p2p_pair = kernel_cache->p2p_pair(); - Memset(ctx->stream(), out->mut_dptr(), 0, - total_elem_cnt * GetSizeOfDataType(out->data_type())); + DeviceType device_type = ctx->device_type(); + + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(device_type); + CHECK(memset_primitive) << "Can not create Memset primitive for device type " << device_type; + memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, + total_elem_cnt * GetSizeOfDataType(out->data_type())); + std::unique_ptr add_primitive = ep::primitive::NewPrimitive(ctx->device_type(), in->data_type()); CHECK(add_primitive); @@ -101,11 +107,11 @@ class EagerPToBKernel final : public user_op::OpKernel { int64_t dst = pair.second; if (GlobalProcessCtx::Rank() == src) { - CHECK_JUST(Send(in_ptr, total_elem_cnt, in->data_type(), dst, ctx->stream())); + CHECK_JUST(Send(in_ptr, total_elem_cnt, in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { - CHECK_JUST(Recv(tmp_buffer_ptr, total_elem_cnt, out->data_type(), src, - ctx->stream())); + CHECK_JUST(Recv(tmp_buffer_ptr, total_elem_cnt, out->data_type(), src, device_type, + ctx->stream())); add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(), total_elem_cnt); } @@ -114,15 +120,9 @@ class EagerPToBKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_EAGER_P_TO_B_KERNEL(device) \ - REGISTER_USER_KERNEL("eager_p_to_b") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device)) \ - .SetInferTmpSizeFn(InferEagerPToBKernelTmpBufferSize); - -REGISTER_EAGER_P_TO_B_KERNEL(DeviceType::kCPU) -#if defined(WITH_CUDA) && HAS_NCCL_SEND_RECV -REGISTER_EAGER_P_TO_B_KERNEL(DeviceType::kCUDA) -#endif +REGISTER_USER_KERNEL("eager_p_to_b") + .SetCreateFn() + .SetIsMatchedHob(HobIsSendAndRecvRegistered()) + .SetInferTmpSizeFn(InferEagerPToBKernelTmpBufferSize); } // namespace oneflow diff --git a/oneflow/user/kernels/eager_p_to_s_kernel.cpp b/oneflow/user/kernels/eager_p_to_s_kernel.cpp index dcbd3913054..b6c1fcf0085 100644 --- a/oneflow/user/kernels/eager_p_to_s_kernel.cpp +++ b/oneflow/user/kernels/eager_p_to_s_kernel.cpp @@ -25,6 +25,7 @@ limitations under the License. 
#include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" #include "oneflow/core/ep/include/primitive/add.h" +#include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { @@ -134,7 +135,6 @@ size_t InferEagerPToSKernelTmpBufferSize(user_op::InferContext* ctx) { } // namespace -template class EagerPToSKernel final : public user_op::OpKernel { public: EagerPToSKernel() = default; @@ -163,8 +163,14 @@ class EagerPToSKernel final : public user_op::OpKernel { const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2_in_tensor_slice_copier.size(), sorted_p2p_pair.size()); - Memset(ctx->stream(), out->mut_dptr(), 0, - elem_cnt_of_this_chunk * GetSizeOfDataType(out->data_type())); + DeviceType device_type = ctx->device_type(); + + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(device_type); + CHECK(memset_primitive) << "Can not create Memset primitive for device type " << device_type; + memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, + elem_cnt_of_this_chunk * GetSizeOfDataType(out->data_type())); + std::unique_ptr add_primitive = ep::primitive::NewPrimitive(ctx->device_type(), in->data_type()); CHECK(add_primitive); @@ -176,12 +182,12 @@ class EagerPToSKernel final : public user_op::OpKernel { const auto& tensor_slice_copier = sorted_elem_cnt2_in_tensor_slice_copier.at(i).second; int64_t send_elem_cnt = sorted_elem_cnt2_in_tensor_slice_copier.at(i).first; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); - CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), send_elem_cnt, - in->data_type(), dst, ctx->stream())); + CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), send_elem_cnt, + in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { - CHECK_JUST(Recv(tmp_buffer_ptr, elem_cnt_of_this_chunk, out->data_type(), src, - ctx->stream())); + CHECK_JUST(Recv(tmp_buffer_ptr, elem_cnt_of_this_chunk, out->data_type(), src, device_type, + ctx->stream())); add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(), elem_cnt_of_this_chunk); } @@ -190,15 +196,9 @@ class EagerPToSKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_EAGER_P_TO_S_KERNEL(device) \ - REGISTER_USER_KERNEL("eager_p_to_s") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device)) \ - .SetInferTmpSizeFn(InferEagerPToSKernelTmpBufferSize); - -REGISTER_EAGER_P_TO_S_KERNEL(DeviceType::kCPU) -#if defined(WITH_CUDA) && HAS_NCCL_SEND_RECV -REGISTER_EAGER_P_TO_S_KERNEL(DeviceType::kCUDA) -#endif +REGISTER_USER_KERNEL("eager_p_to_s") + .SetCreateFn() + .SetIsMatchedHob(HobIsSendAndRecvRegistered()) + .SetInferTmpSizeFn(InferEagerPToSKernelTmpBufferSize); } // namespace oneflow diff --git a/oneflow/user/kernels/eager_s_to_b_kernel.cpp b/oneflow/user/kernels/eager_s_to_b_kernel.cpp index a4af5eaa2f8..280e77b944c 100644 --- a/oneflow/user/kernels/eager_s_to_b_kernel.cpp +++ b/oneflow/user/kernels/eager_s_to_b_kernel.cpp @@ -135,7 +135,6 @@ size_t InferEagerSToBKernelTmpBufferSize(user_op::InferContext* ctx) { } // namespace -template class EagerSToBKernel final : public user_op::OpKernel { public: EagerSToBKernel() = default; @@ -167,6 +166,8 @@ class EagerSToBKernel final : public user_op::OpKernel { CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), 
sorted_p2p_pair.size()); + DeviceType device_type = ctx->device_type(); + for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; @@ -177,8 +178,8 @@ class EagerSToBKernel final : public user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); - CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, - in->data_type(), dst, ctx->stream())); + CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), + dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = @@ -186,7 +187,7 @@ class EagerSToBKernel final : public user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( - Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, ctx->stream())); + Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, reinterpret_cast(tmp_buffer_ptr)); } @@ -195,15 +196,9 @@ class EagerSToBKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_EAGER_S_TO_B_KERNEL(device) \ - REGISTER_USER_KERNEL("eager_s_to_b") \ - .SetCreateFn>() \ - .SetIsMatchedHob(user_op::HobDeviceType() == device) \ - .SetInferTmpSizeFn(InferEagerSToBKernelTmpBufferSize); - -REGISTER_EAGER_S_TO_B_KERNEL(DeviceType::kCPU) -#if defined(WITH_CUDA) && HAS_NCCL_SEND_RECV -REGISTER_EAGER_S_TO_B_KERNEL(DeviceType::kCUDA) -#endif +REGISTER_USER_KERNEL("eager_s_to_b") + .SetCreateFn() + .SetIsMatchedHob(HobIsSendAndRecvRegistered()) + .SetInferTmpSizeFn(InferEagerSToBKernelTmpBufferSize); } // namespace oneflow diff --git a/oneflow/user/kernels/eager_s_to_p_kernel.cpp b/oneflow/user/kernels/eager_s_to_p_kernel.cpp index 5e076f2df3c..a65e94e9093 100644 --- a/oneflow/user/kernels/eager_s_to_p_kernel.cpp +++ b/oneflow/user/kernels/eager_s_to_p_kernel.cpp @@ -152,7 +152,6 @@ size_t InferEagerSToPKernelTmpBufferSize(user_op::InferContext* ctx) { } // namespace -template class EagerSToPKernel final : public user_op::OpKernel { public: EagerSToPKernel() = default; @@ -177,8 +176,14 @@ class EagerSToPKernel final : public user_op::OpKernel { void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const int64_t total_elem_cnt = ctx->Attr("shape").elem_cnt(); - Memset(ctx->stream(), out->mut_dptr(), 0, - total_elem_cnt * GetSizeOfDataType(out->data_type())); + + DeviceType device_type = ctx->device_type(); + + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(device_type); + CHECK(memset_primitive) << "Can not create Memset primitive for device type " << device_type; + memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, + total_elem_cnt * GetSizeOfDataType(out->data_type())); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); @@ -205,8 +210,8 @@ class EagerSToPKernel final : public user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); - CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, - 
in->data_type(), dst, ctx->stream())); + CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), + dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = @@ -214,7 +219,7 @@ class EagerSToPKernel final : public user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( - Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, ctx->stream())); + Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, reinterpret_cast(tmp_buffer_ptr)); } @@ -223,15 +228,9 @@ class EagerSToPKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_EAGER_S_TO_B_KERNEL(device) \ - REGISTER_USER_KERNEL("eager_s_to_p") \ - .SetCreateFn>() \ - .SetIsMatchedHob(user_op::HobDeviceType() == device) \ - .SetInferTmpSizeFn(InferEagerSToPKernelTmpBufferSize); - -REGISTER_EAGER_S_TO_B_KERNEL(DeviceType::kCPU) -#if defined(WITH_CUDA) && HAS_NCCL_SEND_RECV -REGISTER_EAGER_S_TO_B_KERNEL(DeviceType::kCUDA) -#endif +REGISTER_USER_KERNEL("eager_s_to_p") + .SetCreateFn() + .SetIsMatchedHob(HobIsSendAndRecvRegistered()) + .SetInferTmpSizeFn(InferEagerSToPKernelTmpBufferSize); } // namespace oneflow diff --git a/oneflow/user/kernels/eager_s_to_s_kernel.cpp b/oneflow/user/kernels/eager_s_to_s_kernel.cpp index 452890b4e43..b9a14d03e77 100644 --- a/oneflow/user/kernels/eager_s_to_s_kernel.cpp +++ b/oneflow/user/kernels/eager_s_to_s_kernel.cpp @@ -19,7 +19,6 @@ limitations under the License. #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/core/ccl/ccl.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" @@ -136,7 +135,6 @@ size_t InferNaiveSToSKernelTmpBufferSize(user_op::InferContext* ctx) { } // namespace -template class EagerNaiveSToSKernel final : public user_op::OpKernel { public: EagerNaiveSToSKernel() = default; @@ -168,6 +166,8 @@ class EagerNaiveSToSKernel final : public user_op::OpKernel { CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); + DeviceType device_type = ctx->device_type(); + for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; @@ -178,8 +178,8 @@ class EagerNaiveSToSKernel final : public user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); - CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, - in->data_type(), dst, ctx->stream())); + CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), + dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = @@ -187,7 +187,7 @@ class EagerNaiveSToSKernel final : public user_op::OpKernel { const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( - Recv(tmp_buffer_ptr, 
elem_cnt, out->data_type(), src, ctx->stream())); + Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, reinterpret_cast(tmp_buffer_ptr)); } @@ -196,16 +196,9 @@ class EagerNaiveSToSKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_EAGER_NAIVE_S_TO_S_KERNEL(device) \ - REGISTER_USER_KERNEL("eager_naive_s_to_s") \ - .SetCreateFn>() \ - .SetIsMatchedHob(user_op::HobDeviceType() == device) \ - .SetInferTmpSizeFn(InferNaiveSToSKernelTmpBufferSize); - -REGISTER_EAGER_NAIVE_S_TO_S_KERNEL(DeviceType::kCPU) - -#if defined(WITH_CUDA) && HAS_NCCL_SEND_RECV -REGISTER_EAGER_NAIVE_S_TO_S_KERNEL(DeviceType::kCUDA) -#endif +REGISTER_USER_KERNEL("eager_naive_s_to_s") + .SetCreateFn() + .SetIsMatchedHob(HobIsSendAndRecvRegistered()) + .SetInferTmpSizeFn(InferNaiveSToSKernelTmpBufferSize); } // namespace oneflow diff --git a/oneflow/user/kernels/p2p_comm_kernel.cpp b/oneflow/user/kernels/p2p_comm_kernel.cpp index 0e21933147d..984ca798873 100644 --- a/oneflow/user/kernels/p2p_comm_kernel.cpp +++ b/oneflow/user/kernels/p2p_comm_kernel.cpp @@ -15,16 +15,36 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/core/ccl/ccl.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/instructions_builder.h" +#include "oneflow/user/kernels/collective_communication/include/send.h" +#include "oneflow/user/kernels/collective_communication/include/recv.h" namespace oneflow { namespace { -template +namespace { + +auto SendCollectiveCommunicationExists() { + return hob::make_custom("SendCollectiveCommunicationExists", + [=](const user_op::KernelRegContext& ctx) { + DeviceType device_type = ctx.device_type(); + return ccl::IsSendRegistered(device_type); + }); +} + +auto RecvCollectiveCommunicationExists() { + return hob::make_custom("RecvCollectiveCommunicationExists", + [=](const user_op::KernelRegContext& ctx) { + DeviceType device_type = ctx.device_type(); + return ccl::IsRecvRegistered(device_type); + }); +} + +} // namespace + class SendKernel final : public user_op::OpKernel { public: SendKernel() = default; @@ -34,13 +54,13 @@ class SendKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const auto& dst_process_id = ctx->Attr("dst_process_id"); - CHECK_JUST(ccl::Send(in->dptr(), in->shape_view().elem_cnt(), in->data_type(), - dst_process_id, ctx->stream())); + std::unique_ptr send = + ccl::NewCollectiveCommunication(ctx->device_type(), in->data_type()); + send->Launch(ctx->stream(), in->dptr(), in->shape_view().elem_cnt(), dst_process_id); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -template class RecvKernel final : public user_op::OpKernel { public: RecvKernel() = default; @@ -50,22 +70,18 @@ class RecvKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto& src_process_id = ctx->Attr("src_process_id"); - CHECK_JUST(ccl::Recv(out->mut_dptr(), out->shape_view().elem_cnt(), - out->data_type(), src_process_id, ctx->stream())); + std::unique_ptr recv = + ccl::NewCollectiveCommunication(ctx->device_type(), 
out->data_type()); + recv->Launch(ctx->stream(), out->mut_dptr(), out->shape_view().elem_cnt(), src_process_id); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_KERNEL(device) \ - REGISTER_USER_KERNEL("send").SetCreateFn>().SetIsMatchedHob( \ - (user_op::HobDeviceType() == device)); \ - REGISTER_USER_KERNEL("recv").SetCreateFn>().SetIsMatchedHob( \ - (user_op::HobDeviceType() == device)); +REGISTER_USER_KERNEL("send").SetCreateFn().SetIsMatchedHob( + SendCollectiveCommunicationExists()); -REGISTER_KERNEL(DeviceType::kCPU) -#ifdef WITH_CUDA -REGISTER_KERNEL(DeviceType::kCUDA) -#endif +REGISTER_USER_KERNEL("recv").SetCreateFn().SetIsMatchedHob( + RecvCollectiveCommunicationExists()); } // namespace } // namespace oneflow
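
Editorial note (not part of the patch): after this refactor, callers no longer invoke the templated ccl::Send<device>/ccl::Recv<device> helpers; they look up device-registered Send/Recv primitives through the collective-communication registry shown in this diff. The sketch below illustrates that usage pattern. It mirrors the calls in communicate_util.cpp and p2p_comm_kernel.cpp above; the helper name ExchangeWithPeers and its parameters are hypothetical placeholders, not code from the repository.

#include "oneflow/user/kernels/collective_communication/include/send.h"
#include "oneflow/user/kernels/collective_communication/include/recv.h"

namespace oneflow {

// Hypothetical helper: send `elem_cnt` elements of `dtype` from `in` to rank `dst`,
// then receive the same amount from rank `src` into `out`. Works on any device type
// with registered Send/Recv implementations (CPU and CUDA after this patch).
inline void ExchangeWithPeers(ep::Stream* stream, DeviceType device_type, DataType dtype,
                              const void* in, int64_t dst, void* out, int64_t src,
                              size_t elem_cnt) {
  // The registry constructs and initializes a device-specific Send primitive.
  std::unique_ptr<ccl::Send> send =
      ccl::NewCollectiveCommunication<ccl::Send>(device_type, dtype);
  send->Launch(stream, in, elem_cnt, dst);

  // Likewise for Recv; both launches are asynchronous on `stream`.
  std::unique_ptr<ccl::Recv> recv =
      ccl::NewCollectiveCommunication<ccl::Recv>(device_type, dtype);
  recv->Launch(stream, out, elem_cnt, src);
}

}  // namespace oneflow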