add allreduce and broadcast without test (#31024)
lw921014 authored Mar 1, 2021
1 parent 5618f14 commit 9fcdaeb
Showing 29 changed files with 1,895 additions and 19 deletions.
4 changes: 4 additions & 0 deletions cmake/external/ascend.cmake
@@ -62,6 +62,7 @@ endif()
if(WITH_ASCEND_CL)
set(ASCEND_CL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)

set(ascend_hccl_lib ${ASCEND_CL_DIR}/libhccl.so)
set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so)
set(acl_op_compiler_lib ${ASCEND_CL_DIR}/libacl_op_compiler.so)
set(ASCEND_CL_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
@@ -73,6 +74,9 @@ if(WITH_ASCEND_CL)
ADD_LIBRARY(ascendcl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascendcl PROPERTY IMPORTED_LOCATION ${ascendcl_lib})

ADD_LIBRARY(ascend_hccl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascend_hccl PROPERTY IMPORTED_LOCATION ${ascend_hccl_lib})

ADD_LIBRARY(acl_op_compiler SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET acl_op_compiler PROPERTY IMPORTED_LOCATION ${acl_op_compiler_lib})
add_custom_target(extern_ascend_cl DEPENDS ascendcl acl_op_compiler)
7 changes: 5 additions & 2 deletions cmake/external/protobuf.cmake
@@ -205,8 +205,11 @@ elseif(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11)
SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
SET(PROTOBUF_TAG v3.8.0)
else()
SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)
SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
SET(PROTOBUF_TAG v3.8.0)

# SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
# SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)
endif()

cache_third_party(${TARGET_NAME}
2 changes: 2 additions & 0 deletions cmake/flags.cmake
@@ -151,6 +151,8 @@ set(COMMON_FLAGS
-Wno-error=int-in-bool-context # Warning in Eigen gcc 7.2
-Wimplicit-fallthrough=0 # Warning in tinyformat.h
-Wno-error=maybe-uninitialized # Warning in boost gcc 7.2
-Wno-error=nonnull-compare # Warning in boost gcc 7.2
-Wno-error=address # Warning in boost gcc 7.2
${fsanitize}
)

2 changes: 2 additions & 0 deletions paddle/fluid/memory/allocation/allocator_facade.cc
@@ -79,6 +79,7 @@ class AllocatorFacadePrivate {
InitNaiveBestFitCUDAPinnedAllocator();
#endif
#ifdef PADDLE_WITH_ASCEND_CL
VLOG(3) << "npu num: " <<platform::GetNPUDeviceCount();
for (int dev_id = 0; dev_id < platform::GetNPUDeviceCount(); ++dev_id) {
InitNaiveBestFitNPUAllocator(platform::NPUPlace(dev_id));
}
@@ -141,6 +142,7 @@ class AllocatorFacadePrivate {
(size > 0 ? (UNLIKELY(FLAGS_use_system_allocator) ? system_allocators_
: allocators_)
: zero_size_allocators_);
VLOG(3) << size;
auto iter = allocators.find(place);
PADDLE_ENFORCE_NE(iter, allocators.end(),
platform::errors::NotFound(
1 change: 1 addition & 0 deletions paddle/fluid/memory/allocation/allocator_strategy.cc
@@ -24,6 +24,7 @@ namespace allocation {
namespace allocation {

static AllocatorStrategy GetStrategyFromFlag() {
VLOG(3) << "FLAGS_allocator_strategy" << FLAGS_allocator_strategy;
if (FLAGS_allocator_strategy == "naive_best_fit") {
return AllocatorStrategy::kNaiveBestFit;
}
23 changes: 14 additions & 9 deletions paddle/fluid/operators/collective/CMakeLists.txt
@@ -11,24 +11,29 @@ foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS})
endforeach()

register_operators(EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
register_operators(EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})

if(WITH_NCCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
cc_library(gen_nccl_id_op_helper SRCS gen_nccl_id_op_helper.cc DEPS nccl_common)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} gen_nccl_id_op_helper)
op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} gen_nccl_id_op_helper)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
endif()

if(WITH_ASCEND)
op_library(gen_nccl_id_op)
op_library(c_gen_nccl_id_op)
if(WITH_GLOO)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
endif()

if(WITH_XPU_BKCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
op_library(c_gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
op_library(gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
endif()

if(WITH_GLOO)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
if(WITH_ASCEND_CL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
endif()

set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")

cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hccl_op c_create_group_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
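Note: although the commit message says "without test", the cc_test line above already links a c_hcom_op_npu_test target. The test source itself is not shown on this page; a minimal sketch of what such a single-rank smoke test could look like — hypothetical names and body, assuming the communicator was set up beforehand via c_comm_init_hccl_op and c_create_group_op — is:

#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"

namespace f = paddle::framework;
namespace p = paddle::platform;

// Hypothetical smoke-test body: fill X on NPU 0, run c_allreduce_sum,
// then copy Out back to the host for checking.
void TestAllReduceSum(f::Scope* scope, const p::NPUPlace& place) {
  auto& ctx = *p::DeviceContextPool::Instance().Get(place);
  auto* x = scope->Var("X")->GetMutable<f::LoDTensor>();
  scope->Var("Out")->GetMutable<f::LoDTensor>();

  std::vector<float> init(16, 1.0f);
  f::TensorFromVector(init, ctx, x);
  x->Resize({4, 4});

  auto op = f::OpRegistry::CreateOp(
      "c_allreduce_sum", {{"X", {"X"}}}, {{"Out", {"Out"}}},
      {{"ring_id", 0}, {"tag", std::string("tag")}, {"use_calc_stream", true}});
  op->Run(*scope, place);

  std::vector<float> result;
  f::TensorToVector(scope->Var("Out")->Get<f::LoDTensor>(), ctx, &result);
  // With a single rank, the allreduced sum equals the input.
}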
31 changes: 31 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_max_op_npu.cc
@@ -0,0 +1,31 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace platform {
struct ASCENDPlace;
struct float16;
} // namespace platform
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(c_allreduce_max,
ops::CAllReduceOpASCENDKernel<ops::kRedMax, float>,
ops::CAllReduceOpASCENDKernel<ops::kRedMax, int>,
ops::CAllReduceOpASCENDKernel<ops::kRedMax, int8_t>,
ops::CAllReduceOpASCENDKernel<ops::kRedMax, plat::float16>)
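The min, prod, and sum files that follow are identical apart from the ReduceType template argument; the kernel's switch in c_allreduce_op.h (further down) maps that compile-time tag to the corresponding HCCL_REP_OP_* value at run time. The four registrations could be condensed with a helper macro — a sketch, not part of this commit:

// Hypothetical helper; each *_op_npu.cc in this commit spells the
// registration out by hand instead.
#define REGISTER_ASCEND_ALLREDUCE(op_name, red_type)              \
  REGISTER_OP_NPU_KERNEL(op_name,                                 \
      ops::CAllReduceOpASCENDKernel<ops::red_type, float>,        \
      ops::CAllReduceOpASCENDKernel<ops::red_type, int>,          \
      ops::CAllReduceOpASCENDKernel<ops::red_type, int8_t>,       \
      ops::CAllReduceOpASCENDKernel<ops::red_type, plat::float16>)

REGISTER_ASCEND_ALLREDUCE(c_allreduce_max, kRedMax)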
31 changes: 31 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_min_op_npu.cc
@@ -0,0 +1,31 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace platform {
struct ASCENDPlace;
struct float16;
} // namespace platform
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(c_allreduce_min,
ops::CAllReduceOpASCENDKernel<ops::kRedMin, float>,
ops::CAllReduceOpASCENDKernel<ops::kRedMin, int>,
ops::CAllReduceOpASCENDKernel<ops::kRedMin, int8_t>,
ops::CAllReduceOpASCENDKernel<ops::kRedMin, plat::float16>)
94 changes: 93 additions & 1 deletion paddle/fluid/operators/collective/c_allreduce_op.h
@@ -30,6 +30,11 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace paddle {
namespace operators {

@@ -105,6 +110,88 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
}
};

template <ReduceType red_type, typename T>
class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto in = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");

auto place = ctx.GetPlace();
hcclDataType_t dtype = platform::ToHCCLDataType(in->type());

int64_t numel = in->numel();
void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
// void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->mutable_data<T>(place)));

out->Resize(in->dims());
// void* recvbuff = reinterpret_cast<void*>(const_cast<T*>(out->data<T>()));
void* recvbuff = reinterpret_cast<void*>(const_cast<T*>(out->mutable_data<T>(place)));
// void* recvbuff = sendbuff;
std::string tag = ctx.Attr<std::string>("tag");
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
// NOTE: the ring-derived group name is overridden for now; every op runs
// in the default HCCL world group.
group = "hccl_world_group";

auto comm = paddle::platform::HCCLCommContext::Instance().Get();

aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}

hcclRedOp_t hccl_red_type = HCCL_REP_OP_SUM;
switch (red_type) {
case kRedSum:
hccl_red_type = HCCL_REP_OP_SUM;
break;

case kRedMax:
hccl_red_type = HCCL_REP_OP_MAX;
break;

case kRedMin:
hccl_red_type = HCCL_REP_OP_MIN;
break;

case kRedProd:
hccl_red_type = HCCL_REP_OP_PROD;
break;

default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid reduce type: %d", red_type));
}


VLOG(3) << "begin hccl allreduce, parameter is: "
<< "input num: " << numel
<< "dtype: " << dtype
<< "hccl_red_type: " << hccl_red_type
<< ", group is: " << group
<< ", tag is " << tag;

printf("sendbuff: %p\n", sendbuff);
printf("recvbuff: %p\n", recvbuff);

// printf("sendbuff: %p, %d\n", sendbuff, ((int*)sendbuff)[0]);
// printf("recvbuff: %p, %d\n", recvbuff, ((int*)recvbuff)[0]);

PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_reduce(
tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream));

#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};

template <ReduceType red_type, typename T>
class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
public:
@@ -114,7 +201,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
auto out = ctx.Output<framework::Tensor>("Out");

auto place = ctx.GetPlace();
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
ncclDataType_t dtype = platform::ToHCCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
out->Resize(in->dims());
@@ -170,6 +257,11 @@ class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "(Tensor) the allreduced result.");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
#if defined(PADDLE_WITH_ASCEND_CL)
#pragma message("hccl CAllReduceOpMaker need tag attr")
AddAttr<std::string>("tag", "(string default tag) tag for all reduce.")
.SetDefault("tag");
#endif
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
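Every reduction above funnels into hcom_all_reduce, resolved through platform::dynload. The declaration is not part of this diff; inferred from the call site — an assumption, the actual fwkacllib header may differ — it would look roughly like:

// Signature inferred from the hcom_all_reduce call site above; an
// assumption, not the verbatim fwkacllib declaration.
extern "C" hcclResult_t hcom_all_reduce(
    const char* tag,           // per-op tag, from the new "tag" attribute
    void* input_ptr,           // sendbuff
    void* output_ptr,          // recvbuff (Out, resized to X's dims)
    uint64_t count,            // element count (numel)
    hcclDataType_t data_type,  // from platform::ToHCCLDataType
    hcclRedOp_t op,            // kRed* mapped to HCCL_REP_OP_*
    const char* group,         // currently hard-coded to "hccl_world_group"
    void* stream);             // calc stream or the communicator's stream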
31 changes: 31 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_prod_op_npu.cc
@@ -0,0 +1,31 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace platform {
struct ASCENDPlace;
struct float16;
} // namespace platform
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(c_allreduce_prod,
ops::CAllReduceOpASCENDKernel<ops::kRedProd, float>,
ops::CAllReduceOpASCENDKernel<ops::kRedProd, int>,
ops::CAllReduceOpASCENDKernel<ops::kRedProd, int8_t>,
ops::CAllReduceOpASCENDKernel<ops::kRedProd, plat::float16>)
31 changes: 31 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_sum_op_npu.cc
@@ -0,0 +1,31 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace platform {
struct ASCENDPlace;
struct float16;
} // namespace platform
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(c_allreduce_sum,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, float>,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, int>,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, int8_t>,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, plat::float16>)
5 changes: 5 additions & 0 deletions paddle/fluid/operators/collective/c_broadcast_op.cc
@@ -42,6 +42,11 @@ class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(0);
AddAttr<int>("root", "(int default 0) root id for broadcasting.")
.SetDefault(0);
#if defined(PADDLE_WITH_ASCEND_CL)
#pragma message("tag")
AddAttr<std::string>("tag", "(string default tag) tag for broadcasting.")
.SetDefault("tag");
#endif
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
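For ASCEND builds, both collectives now carry the guarded tag attribute alongside ring_id and root. A sketch of attaching it when composing a program with framework::OpDesc — illustrative variable and tag names only:

#include <string>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"

// Illustrative only: build a c_broadcast op carrying the new
// ASCEND-guarded "tag" attribute.
void AppendBroadcast(paddle::framework::BlockDesc* block) {
  auto* op = block->AppendOp();
  op->SetType("c_broadcast");
  op->SetInput("X", {"x"});
  op->SetOutput("Out", {"x"});
  op->SetAttr("root", 0);
  op->SetAttr("ring_id", 0);
  op->SetAttr("tag", std::string("bcast_x"));  // defaults to "tag" if omitted
}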
(diffs for the remaining changed files are not shown)
