[NPU] Added HCCL backend support in dygraph mode (#36285)
* Added HCCL backend support in dynamic graph mode

* Fix segmentation fault

* Add unit tests
ronny1996 authored Nov 23, 2021
1 parent e58ac12 commit 83e55cf
Showing 16 changed files with 585 additions and 11 deletions.
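At a high level, the new HCCLParallelContext plugs into the same ParallelContext interface already used by the NCCL and BKCL backends. A minimal usage sketch, assuming a populated ParallelStrategy and one NPU per process (illustrative wiring, not lifted from this diff):

#include "paddle/fluid/imperative/hccl_context.h"

// Hedged sketch: construct and bootstrap the dygraph HCCL context.
imperative::ParallelStrategy strategy;  // nranks_, local_rank_, endpoints, nrings_ filled in beforehand
platform::NPUPlace place(0);            // this process's NPU
imperative::HCCLParallelContext ctx(strategy, place);
ctx.Init();  // rank 0 creates one HcclRootInfo per ring; peers receive it over TCP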
6 changes: 5 additions & 1 deletion paddle/fluid/imperative/CMakeLists.txt
@@ -26,11 +26,15 @@ if(NOT WIN32)
cc_library(bkcl_context SRCS bkcl_context.cc DEPS collective_helper device_context tensor var_type_traits)
cc_library(reducer SRCS reducer.cc DEPS layer)
endif()
if(WITH_ASCEND_CL)
cc_library(hccl_context SRCS hccl_context.cc DEPS collective_helper device_context tensor var_type_traits)
cc_library(reducer SRCS reducer.cc DEPS layer)
endif()
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
if(WITH_GLOO)
cc_library(imperative_gloo_context SRCS gloo_context.cc DEPS collective_helper device_context tensor var_type_traits)
-if ( WIN32 OR (NOT (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL) ))
+if ( WIN32 OR (NOT (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL) ))
cc_library(reducer SRCS reducer.cc DEPS layer)
endif()
endif()
218 changes: 218 additions & 0 deletions paddle/fluid/imperative/hccl_context.cc
@@ -0,0 +1,218 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/hccl_context.h"

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"

#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"

#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"

namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace imperative {

static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
const aclrtStream stream,
const platform::HCCLComm *comm) {
const auto &place = src.place();
PADDLE_ENFORCE_EQ(
platform::is_npu_place(place), true,
platform::errors::Unimplemented(
"Imperative mode does not support multi-CPU training yet."));

void *src_ptr = const_cast<void *>(src.data<void>());
dst->Resize(src.dims());
void *dst_ptr = dst->mutable_data(src.place(), src.type());
HcclDataType hccl_dtype = platform::ToHCCLDataType(src.type());

PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce(
src_ptr, dst_ptr, src.numel(), hccl_dtype, HCCL_REDUCE_SUM, comm->comm(),
reinterpret_cast<void *>(stream)));
}

void HCCLParallelContext::BcastHCCLId(
std::vector<HcclRootInfo> &hccl_ids, // NOLINT
int root, int server_fd) {
if (strategy_.local_rank_ == root) {
std::vector<std::string> other_trainers;
for (auto &ep : strategy_.trainer_endpoints_) {
if (ep != strategy_.current_endpoint_) {
other_trainers.push_back(ep);
}
}
platform::SendBroadCastCommID(other_trainers, &hccl_ids);
} else {
platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
&hccl_ids);
}
}

void HCCLParallelContext::Init() {
int server_fd = -1;

std::vector<HcclRootInfo> hccl_ids;
hccl_ids.resize(strategy_.nrings_);

if (strategy_.local_rank_ == 0) {
// generate the unique HCCL ID on the root worker
for (size_t i = 0; i < hccl_ids.size(); ++i) {
platform::dynload::HcclGetRootInfo(&hccl_ids[i]);
}
} else {
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastHCCLId(hccl_ids, 0, server_fd);

int npu_id = BOOST_GET_CONST(platform::NPUPlace, place_).device;
for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
VLOG(0) << "init hccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " npu id: " << npu_id
<< " ring id: " << ring_id;
// CreateHCCLComm assigns the hccl_comm of this ring_id to the NPUDeviceContext
platform::HCCLCommContext::Instance().CreateHCCLComm(
&hccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, npu_id,
ring_id);

compute_events_.emplace_back(platform::NpuEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::NPUPlace, place_).device));
comm_events_.emplace_back(platform::NpuEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::NPUPlace, place_).device));
}
}

void HCCLParallelContext::InitWithRingID(int ring_id) {
int server_fd = -1;
std::vector<HcclRootInfo> hccl_ids;
hccl_ids.resize(1);

if (strategy_.local_rank_ == 0) {
// generate the unique HCCL ID on the root worker
platform::dynload::HcclGetRootInfo(&hccl_ids[0]);
} else {
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastHCCLId(hccl_ids, 0, server_fd);

int npu_id = BOOST_GET_CONST(platform::NPUPlace, place_).device;
VLOG(0) << "init hccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " npu id: " << npu_id
<< " ring id: " << ring_id;
// CreateHCCLComm assigns the hccl_comm of this ring_id to the NPUDeviceContext
platform::HCCLCommContext::Instance().CreateHCCLComm(
&hccl_ids[0], strategy_.nranks_, strategy_.local_rank_, npu_id, ring_id);

compute_events_.emplace_back(platform::NpuEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::NPUPlace, place_).device));
comm_events_.emplace_back(platform::NpuEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::NPUPlace, place_).device));
}

void HCCLParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id, bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
platform::is_npu_place(place_), true,
platform::errors::Unimplemented(
"Dynamic graph mode does not support multi-CPU training yet."));
auto *dev_ctx = static_cast<platform::NPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
platform::HCCLComm *comm =
platform::HCCLCommContext::Instance().Get(ring_id, place_);
aclrtStream stream = use_calc_stream ? dev_ctx->stream() : comm->stream();

if (src.IsType<framework::LoDTensor>()) {
if (!dst->IsType<framework::LoDTensor>()) {
dst->Clear();
}
AllReduce(src.Get<framework::LoDTensor>(),
dst->GetMutable<framework::LoDTensor>(), stream, comm);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"XPU unsupported variable type %s for imperative allreduce, only "
"LoDTensor are supported.",
platform::demangle(framework::ToTypeName(src.Type()))));
}
}

paddle::platform::DeviceContext *HCCLParallelContext::GetDeviceContext(
int ring_id) {
return static_cast<platform::DeviceContext *>(
platform::HCCLCommContext::Instance()
.Get(ring_id, place_)
->dev_context());
}

void HCCLParallelContext::WaitCompute(int ring_id) {
PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
"ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id, compute_events_.size(),
platform::errors::OutOfRange(
"ring id must < compute events size,"
"but got ring id = %d, compute events size = %d",
ring_id, compute_events_.size()));

auto compute_stream = static_cast<platform::NPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream =
platform::HCCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = compute_events_[ring_id].get();

// compute_stream-->event-->comm_stream
PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(event, compute_stream));
PADDLE_ENFORCE_NPU_SUCCESS(aclrtStreamWaitEvent(comm_stream, event));
}

void HCCLParallelContext::WaitComm(int ring_id) {
PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
"ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id, comm_events_.size(),
platform::errors::OutOfRange(
"ring id must < comm events size,"
"but got ring id = %d, comm events size = %d",
ring_id, comm_events_.size()));

auto compute_stream = static_cast<platform::NPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream =
platform::HCCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = comm_events_[ring_id].get();

// comm_stream-->event-->compute_stream
PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(event, comm_stream));
PADDLE_ENFORCE_NPU_SUCCESS(aclrtStreamWaitEvent(compute_stream, event));
}

void HCCLParallelContext::SynchronizeCompute() {
auto *compute_dev_ctx = static_cast<platform::NPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
}

} // namespace imperative
} // namespace paddle
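The compute_events_ and comm_events_ pools created in Init() implement one-way, device-side ordering between the compute stream and each ring's communication stream. A hedged sketch of how a caller such as the gradient reducer is expected to sequence an asynchronous allreduce (driver code assumed, not part of this diff):

// compute stream produces grads --> comm stream reduces them --> compute stream consumes them
// grad_var: a framework::Variable holding a LoDTensor of gradients
ctx.WaitCompute(ring_id);                           // comm stream waits until grads are ready
ctx.AllReduceByStream(grad_var, &grad_var, ring_id,
                      /*use_calc_stream=*/false);   // HcclAllReduce runs on the comm stream
ctx.WaitComm(ring_id);                              // compute stream waits for the reduced result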
71 changes: 71 additions & 0 deletions paddle/fluid/imperative/hccl_context.h
@@ -0,0 +1,71 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/dynload/hccl.h"
#include "paddle/fluid/platform/npu_resource_pool.h"

namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace imperative {

class HCCLParallelContext : public ParallelContext {
public:
explicit HCCLParallelContext(const ParallelStrategy& strategy,
const platform::Place& place)
: ParallelContext(strategy, place) {}

~HCCLParallelContext() override = default;

void BcastHCCLId(std::vector<HcclRootInfo>& hccl_ids, int root, // NOLINT
int server_fd);

void Init() override;

void InitWithRingID(int ring_id) override;

void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;

paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;

void WaitCompute(int ring_id) override;

void WaitComm(int ring_id) override;

void SynchronizeCompute() override;

private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std::vector<std::shared_ptr<platform::NpuEventObject>> compute_events_;

// used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
std::vector<std::shared_ptr<platform::NpuEventObject>> comm_events_;
};

} // namespace imperative
} // namespace paddle
#endif
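Because HCCLParallelContext only overrides the generic ParallelContext interface, call sites can stay backend-agnostic. A hedged factory sketch showing how a dygraph trainer might pick this context for NPU places (the factory itself is assumed, not part of this diff):

std::unique_ptr<imperative::ParallelContext> CreateParallelContext(
    const imperative::ParallelStrategy &strategy, const platform::Place &place) {
#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    return std::unique_ptr<imperative::ParallelContext>(
        new imperative::HCCLParallelContext(strategy, place));
  }
#endif
  // ... NCCL / BKCL / Gloo branches elided ...
  PADDLE_THROW(platform::errors::Unimplemented(
      "No parallel context is available for this place."));
}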
20 changes: 20 additions & 0 deletions paddle/fluid/imperative/reducer.cc
@@ -228,6 +228,16 @@ void Group::ConcatTensors(const platform::DeviceContext &context) {
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat xpu grads since it's not compiled with BKCL,"
"Please recompile or reinstall Paddle with BKCL support."));
#endif
} else if (platform::is_npu_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
ConcatTensorsWithType(
static_cast<const platform::NPUDeviceContext &>(context),
dense_tensors_, &dense_contents_, dtype_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat npu grads since it's not compiled with HCCL,"
"Please recompile or reinstall Paddle with HCCL support."));
#endif
} else if (platform::is_cpu_place(place)) {
ConcatTensorsWithType(
@@ -260,6 +270,16 @@ void Group::SplitTensors(const platform::DeviceContext &context) {
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split xpu grad since it's not compiled with BKCL,"
"Please recompile or reinstall Paddle with BKCL support."));
#endif
} else if (platform::is_npu_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
SplitTensorsWithType(
static_cast<const platform::NPUDeviceContext &>(context),
&dense_contents_, &dense_tensors_, dtype_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split npu grad since it's not compiled with HCCL,"
"Please recompile or reinstall Paddle with HCCL support."));
#endif
} else if (platform::is_cpu_place(place)) {
SplitTensorsWithType(
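ConcatTensorsWithType and SplitTensorsWithType are small dtype dispatchers templated on the device context, which is why the NPU branches above only need to pass NPUDeviceContext through. A plausible shape of the concat side, hedged (the helper's body sits outside this diff):

template <typename DeviceContext>
static void ConcatTensorsWithType(
    const DeviceContext &context,
    const std::vector<framework::Tensor> &dense_tensors,
    framework::Variable *p_dense_contents,
    framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP16:
      ConcatTensorsForAllReduce<DeviceContext, platform::float16>(
          context, dense_tensors, p_dense_contents);
      break;
    case framework::proto::VarType::FP32:
      ConcatTensorsForAllReduce<DeviceContext, float>(context, dense_tensors,
                                                      p_dense_contents);
      break;
    case framework::proto::VarType::FP64:
      ConcatTensorsForAllReduce<DeviceContext, double>(context, dense_tensors,
                                                       p_dense_contents);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it concats tensors for "
          "allreduce.",
          framework::DataTypeToString(type)));
  }
}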
4 changes: 4 additions & 0 deletions paddle/fluid/operators/math/CMakeLists.txt
@@ -44,7 +44,11 @@ if (WITH_ASCEND_CL)
endif()

# please add new math_library in alphabetical order
if (WITH_ASCEND_CL)
math_library(concat_and_split DEPS npu_op_runner)
else()
math_library(concat_and_split)
endif()
math_library(context_project DEPS im2col math_function)
math_library(cross_entropy)
math_library(cos_sim_functor)
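The new npu_op_runner dependency hints that the NPU build of concat_and_split lowers to CANN operators through NpuOpRunner rather than hand-written kernels. An illustrative, heavily hedged shape of such a functor — the op name, input plumbing, and attributes below are assumptions, not lifted from this diff:

// Hypothetical sketch: an NPU ConcatFunctor delegating to a device concat op.
template <typename T>
class ConcatFunctor<platform::NPUDeviceContext, T> {
 public:
  void operator()(const platform::NPUDeviceContext &context,
                  const std::vector<framework::Tensor> &input, int axis,
                  framework::Tensor *output) {
    // The "ConcatD" op and the attribute names ("concat_dim", "N") are assumed
    // here; see paddle/fluid/operators/math/concat_and_split.cc for the real wiring.
    NpuOpRunner runner("ConcatD", {/* input tensors */}, {*output},
                       {{"concat_dim", axis},
                        {"N", static_cast<int>(input.size())}});
    runner.Run(context.stream());
  }
};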