support KL2 multi-card training, *test=kunlun
    * update xccl lib
    * use separate streams for compute/comm on XPU
    * add broadcast op to xpu2_op_list
XiaociZhang committed Jul 15, 2022
1 parent ec38be6 commit b335f37
Showing 10 changed files with 82 additions and 56 deletions.
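The heart of the change is the second bullet: on XPU, gradient communication now runs on a dedicated per-ring stream instead of the default compute stream, and the two streams are ordered with events rather than blocking device waits. A condensed sketch of that handshake as implemented in bkcl_context.cc below, using only the runtime calls that appear in the diff (stream/event creation and the PADDLE_ENFORCE_XPU_SUCCESS checks are omitted here):

// WaitCompute(ring_id): communication queued after this point must not start
// until everything already queued on the compute stream has finished.
//   compute_stream --> event --> comm_stream
xpu_event_record(compute_events_[ring_id].get(), compute_stream);
xpu_stream_wait_event(comm_stream, compute_events_[ring_id].get());

// WaitComm(ring_id): the reverse handshake, so the compute stream does not
// touch the fused gradient buffers before the allreduce on this ring is done.
//   comm_stream --> event --> compute_stream
xpu_event_record(comm_events_[ring_id].get(), comm_stream);
xpu_stream_wait_event(compute_stream, comm_events_[ring_id].get());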
5 changes: 4 additions & 1 deletion cmake/external/xpu.cmake
@@ -24,6 +24,9 @@ else()
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif()

set(XPU_XCCL_BASE_URL
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.0")

if(WITH_AARCH64)
set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
set(XPU_XDNN_DIR_NAME "xdnn-kylin_aarch64")
@@ -76,7 +79,7 @@ set(XPU_XRE_URL
"${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz"
CACHE STRING "" FORCE)
set(XPU_XCCL_URL
"${XPU_BASE_URL_WITHOUT_DATE}/20220411/${XPU_XCCL_DIR_NAME}.tar.gz"
"${XPU_XCCL_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz"
CACHE STRING "" FORCE)
set(XPU_PACK_DEPENCE_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh"
39 changes: 33 additions & 6 deletions paddle/fluid/imperative/bkcl_context.cc
@@ -110,6 +110,10 @@ void BKCLParallelContext::Init() {
strategy_.local_rank_,
xpu_id,
ring_id);
compute_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
comm_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
}
}

@@ -134,6 +138,11 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateComm(
&bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);

compute_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
comm_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
}

void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
@@ -213,9 +222,18 @@ void BKCLParallelContext::WaitCompute(int ring_id) {
"but got ring id = %d, nrings = %d",
ring_id,
strategy_.nrings_));
auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
auto compute_stream = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream = platform::BKCLCommContext::Instance()
.Get(ring_id, place_)
->dev_context()
->stream();
auto event = compute_events_[ring_id].get();

// compute_stream-->event-->comm_stream
PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, compute_stream));
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(comm_stream, event));
}

void BKCLParallelContext::WaitComm(int ring_id) {
@@ -230,9 +248,18 @@ void BKCLParallelContext::WaitComm(int ring_id) {
"but got ring id = %d, nrings = %d",
ring_id,
strategy_.nrings_));
auto comm_dev_ctx =
platform::BKCLCommContext::Instance().Get(ring_id, place_)->dev_context();
comm_dev_ctx->Wait();
auto comm_stream = platform::BKCLCommContext::Instance()
.Get(ring_id, place_)
->dev_context()
->stream();
auto compute_stream = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto event = compute_events_[ring_id].get();

// comm_stream-->event-->compute_stream
PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, comm_stream));
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(compute_stream, event));
}

void BKCLParallelContext::SynchronizeCompute() {
8 changes: 8 additions & 0 deletions paddle/fluid/imperative/bkcl_context.h
@@ -19,6 +19,7 @@
#include <vector>

#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device/xpu/xpu_resource_pool.h"
#include "xpu/bkcl.h"

namespace paddle {
@@ -52,6 +53,13 @@ class BKCLParallelContext : public ParallelContext {
void WaitComm(int ring_id) override;

void SynchronizeCompute() override;

private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std::vector<std::shared_ptr<platform::XpuEventObject>> compute_events_;

// used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
std::vector<std::shared_ptr<platform::XpuEventObject>> comm_events_;
};

} // namespace imperative
69 changes: 20 additions & 49 deletions paddle/fluid/imperative/reducer.cc
@@ -21,6 +21,9 @@
#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#ifdef PADDLE_WITH_XPU_BKCL
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
@@ -431,10 +434,6 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
VLOG(3) << "Start construct the Reducer ...";
nrings_ = parallel_ctx->GetNRings();
nranks_ = parallel_ctx->GetNRanks();
#ifdef PADDLE_WITH_XPU_BKCL
comm_pool_.reset(new ::ThreadPool(1));
comm_op_count_ = 0;
#endif
// initialize groups
InitializeGroups(group_indices);
for (size_t global_var_index = 0; global_var_index < vars_.size();
@@ -853,8 +852,23 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {

#ifdef PADDLE_WITH_XPU_BKCL
if (platform::is_xpu_place(group_tensor.place())) {
// TODO(liuyuhui) support XPU set constant
VLOG(3) << "XPU doesn't support set_constant";
auto dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
if (HasGrad(var_index)) {
auto var_base = vars_[var_index]->GradVarBase();
auto tensor =
var_base->MutableVar()->GetMutable<framework::LoDTensor>();
group_tensor.ShareDataWith(*tensor).Resize(
{static_cast<int64_t>(length)});
} else {
group_tensor.Resize({static_cast<int64_t>(length)});
int r = xpu::constant(dev_ctx->x_context(),
reinterpret_cast<float *>(group_tensor.data()),
group_tensor.numel(),
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx->stream()));
}
}
#elif defined(PADDLE_WITH_CNCL)
if (platform::is_mlu_place(group_tensor.place())) {
@@ -948,33 +962,7 @@ void Reducer::MarkGroupReady(size_t group_index) {
// so we expose WaitCompute() interface and call
// it here.
parallel_ctx_->WaitCompute(run_order);
#ifdef PADDLE_WITH_XPU_BKCL
{
std::lock_guard<std::mutex> lock(mutex_);
comm_op_count_ += 1; // lock
}
// TODO(liuyuhui): Add try catch to deal with exception later,
// otherwise the main thread will continue to run when an exception is
// thrown in comm_pool_.
auto next_group = next_group_;
comm_pool_->enqueue([this, run_order, next_group, &group] {
auto dev_id = place_.device;
platform::SetXPUDeviceId(dev_id);
FusedAllReduceSchedule(run_order, group, next_group);
{
std::lock_guard<std::mutex> lock(mutex_);
comm_op_count_ -= 1; // lock
cv_.notify_all();
}
});
#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \
defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) || \
defined(PADDLE_WITH_CNCL)
FusedAllReduceSchedule(run_order, group, next_group_);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Not compiled with BKCL or NCCL or CNCL or GLOO."));
#endif
}
}

@@ -997,17 +985,6 @@ void Reducer::FusedAllReduceSchedule(const int run_order,
// group.dense_tensors ---> group.dense_contents_
group.ConcatTensors(dev_context);

// NOTE(liuyuhui): ConcatTensors use communication stream, but BKCL only support
// default stream for communicating, so there exist some problems in
// synchronization. And need to add a WaitComm there.
// TODO(liuyuhui): If BKCL support non-blocking communication, it should be
// fixed as multi gpus card training.
#ifdef PADDLE_WITH_XPU_BKCL
if (platform::is_xpu_place(group.dense_tensors_[0].place())) {
parallel_ctx_->WaitComm(run_order);
}
#endif

group.DivNRanks(dev_context, nranks_);
// Start allreduce
parallel_ctx_->AllReduceByStream(
@@ -1135,12 +1112,6 @@ bool Reducer::HasGrad(size_t var_index) {
void Reducer::FinalizeBackward() {
groups_need_finalize_ = false;
grad_need_hooks_ = false;
#ifdef PADDLE_WITH_XPU_BKCL
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return comm_op_count_ == 0; });
}
#endif

// Must prevent compute_stream_ starting until all comm streams have finished
for (int i = 0; i < nrings_; ++i) {
6 changes: 6 additions & 0 deletions paddle/fluid/platform/collective_helper.cc
@@ -347,6 +347,12 @@ BKCLComm* BKCLCommContext::AssignBKCLComm(
BKCLContext_t comm, int nranks, int rank, int dev_id, int ring_id) {
std::unique_ptr<XPUDeviceContext> dev_ctx(
new XPUDeviceContext(XPUPlace(dev_id)));
// used in BKCL as comm_stream, for every dev_id there is
// a comm_stream at each ring. this stream is passed as input var
// when calling collective comm commands like bkcl_all_reduce
XPUStream comm_stream;
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream));
dev_ctx->SetXPUStream(comm_stream);

BKCLCommImpl* c = new BKCLCommImpl;
c->set_ring_id(ring_id);
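Taken together with the SetXPUStream hook added to xpu_context.cc/.h further down, the block above gives each ring's BKCL communicator its own XPUDeviceContext whose stream is separate from the default compute stream. A minimal sketch of that plumbing, assuming only the calls shown in this commit (error checks omitted):

// One XPUDeviceContext per (dev_id, ring_id); its stream is used as the
// comm_stream for collectives on that ring.
std::unique_ptr<XPUDeviceContext> dev_ctx(
    new XPUDeviceContext(XPUPlace(dev_id)));

XPUStream comm_stream;
xpu_stream_create(&comm_stream);     // dedicated stream, not the compute stream
dev_ctx->SetXPUStream(comm_stream);  // collectives on this ring now run on it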
1 change: 1 addition & 0 deletions paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -60,6 +60,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"cast",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
2 changes: 2 additions & 0 deletions paddle/phi/backends/all_context.h
@@ -23,7 +23,9 @@ limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/xpu_context.h"
#endif

#ifndef PADDLE_WITH_CUSTOM_KERNEL
// TODO(wilber): DeviceContextPool nees include fluid file.
4 changes: 4 additions & 0 deletions paddle/phi/backends/xpu/xpu_context.cc
@@ -66,6 +66,8 @@ struct XPUContext::Impl {

const Place& GetPlace() const { return place_; }

void SetStream(XPUStream stream) { context_->xpu_stream = stream; }

xpu::Context* GetXContext() const {
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
return context_;
@@ -115,6 +117,8 @@ XPUContext::~XPUContext() = default;

const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); }

void XPUContext::SetXPUStream(XPUStream stream) { impl_->SetStream(stream); }

backends::xpu::XPUVersion XPUContext::xpu_version() const {
return impl_->xpu_version_;
}
2 changes: 2 additions & 0 deletions paddle/phi/backends/xpu/xpu_context.h
@@ -61,6 +61,8 @@ class XPUContext : public DeviceContext {

void SetL3Cache(int l3_size = 14155776);

void SetXPUStream(XPUStream stream);

private:
struct Impl;
std::unique_ptr<Impl> impl_;
2 changes: 2 additions & 0 deletions paddle/phi/core/kernel_utils.h
@@ -18,7 +18,9 @@
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/xpu_context.h"
#endif
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
