add all ccl op for Ascendrc #31437

Merged: 46 commits, Mar 8, 2021

Commits
2d4a8b8
add allreduce and broadcast without test
lw921014 Feb 18, 2021
3e7c453
add c_broadcast_test case
lw921014 Feb 19, 2021
c287eb3
build c_comm_init and c_create_group operators
void-main Feb 19, 2021
8ff3c5b
make the whole thing compile
void-main Feb 19, 2021
3a059f4
add broadcast and init op test case but run failed
lw921014 Feb 19, 2021
9f862dd
make unit test compile
void-main Feb 19, 2021
73d490e
fix broadcast test bug and change into hcom for ccl
lw921014 Feb 19, 2021
4717777
change c_comm_init and c_create_group ops accordingly
void-main Feb 19, 2021
f2039c6
make tests compile
void-main Feb 19, 2021
e0cee0d
transfer code to 27
lw921014 Feb 22, 2021
cfd2f0c
compiled successfully in 28, but run failed
void-main Mar 1, 2021
b334c7d
test broadcast in 28, but failed
lw921014 Feb 26, 2021
5a79406
make hcom primitives work
void-main Feb 27, 2021
d3f1b16
change hccl data type for base.h
lw921014 Mar 1, 2021
c5140e7
fix broadcast bug
lw921014 Mar 1, 2021
30ed979
make attributes work
void-main Mar 1, 2021
6583046
fix group name bug
lw921014 Mar 1, 2021
cec9f15
add allreduce but test failed
lw921014 Mar 1, 2021
7fdf5d7
allreduce bug for qiuliang
lw921014 Mar 1, 2021
8fccb14
allreduce finished
lw921014 Mar 2, 2021
284d1d2
add allgather and reducescatter
lw921014 Mar 2, 2021
12014fc
ccl op mergered
lw921014 Mar 2, 2021
4bafcd8
merge all op code
lw921014 Mar 2, 2021
2760d95
add allgather test
lw921014 Mar 2, 2021
b96578b
finish run all ccl op test exclude send/recv
lw921014 Mar 3, 2021
25277e6
all all op and test exclude send/recv
lw921014 Mar 3, 2021
0d96158
send_v2_npu.cc recv_v2_npiu.cc compiled
f2hkop Mar 3, 2021
f6d7070
fix ccl core dump bug and test allgather, reducescatter, broadcast op
lw921014 Mar 3, 2021
e167570
fix allreduce bug just for test
lw921014 Mar 4, 2021
6912c23
hcom send&recv test pass, without hcom_destroy
f2hkop Mar 4, 2021
a016e95
Merge branch 'ascendrc' of https://github.com/lw921014/Paddle into as…
f2hkop Mar 4, 2021
97b40f9
for qiuliang test
lw921014 Mar 4, 2021
6116b12
Ascend Send&Recv Test Pass
f2hkop Mar 4, 2021
2168739
all op (ex send/recv) ok
lw921014 Mar 4, 2021
ed6bc36
merge all reduce sum changes
f2hkop Mar 4, 2021
5e708b8
fix bug
lw921014 Mar 4, 2021
8aff81a
Merge branch 'ascendrc' of https://github.com/lw921014/Paddle into as…
f2hkop Mar 4, 2021
8de7f6b
Merge pull request #7 from f2hkop/ascendrc
lw921014 Mar 4, 2021
3a00bd3
merge all ccl op
lw921014 Mar 4, 2021
7f9312e
Merge branch 'ascendrc' of https://github.com/PaddlePaddle/Paddle int…
lw921014 Mar 5, 2021
bf9f79c
style merge to PaddlePaddle
lw921014 Mar 5, 2021
d5929a1
merge style
lw921014 Mar 5, 2021
17e7697
new merge style
lw921014 Mar 5, 2021
db4b0ae
merge style 2
lw921014 Mar 6, 2021
4917c17
insert an empty at the end
lw921014 Mar 8, 2021
71344fa
disable ctest for hcom to pass ci
lw921014 Mar 8, 2021
21 changes: 19 additions & 2 deletions paddle/fluid/operators/collective/CMakeLists.txt
@@ -30,10 +30,27 @@ if(WITH_XPU_BKCL)
 endif()

 if(WITH_ASCEND_CL)
-  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
+  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
 endif()

 set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
 set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")

-cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
+if(WITH_ASCEND_CL)
+  set(COMMON_TEST_DEPS_FOR_HCOM c_comm_init_hcom_op op_registry ascend_hccl flags
+      dynamic_loader dynload_warpctc scope device_context enforce executor)
+  cc_test(c_broadcast_op_npu_test SRCS c_broadcast_op_npu_test.cc
+    DEPS c_broadcast_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+  cc_test(c_allreduce_sum_op_npu_test SRCS c_allreduce_sum_op_npu_test.cc
+    DEPS c_allreduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+  cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc
+    DEPS c_allreduce_max_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+  cc_test(c_reducescatter_op_npu_test SRCS c_reducescatter_op_npu_test.cc
+    DEPS c_reducescatter_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+  cc_test(c_allgather_op_npu_test SRCS c_allgather_op_npu_test.cc
+    DEPS c_allgather_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+  cc_test(send_v2_op_npu_test SRCS send_v2_op_npu_test.cc
+    DEPS send_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+  cc_test(recv_v2_op_npu_test SRCS recv_v2_op_npu_test.cc
+    DEPS recv_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+endif()
4 changes: 4 additions & 0 deletions paddle/fluid/operators/collective/c_allgather_op.cc
@@ -42,6 +42,10 @@ class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "(Tensor) the allgather result");
     AddAttr<int>("ring_id", "(int default 0) communication ring id.")
         .SetDefault(0);
+#if defined(PADDLE_WITH_ASCEND_CL)
+    AddAttr<std::string>("tag", "(string default tag) tag for all gather.")
+        .SetDefault("tag");
+#endif
     AddAttr<bool>(
         "use_calc_stream",
         "(bool default false) eject CUDA operations to calculation stream.")
86 changes: 86 additions & 0 deletions paddle/fluid/operators/collective/c_allgather_op_npu.cc
@@ -0,0 +1,86 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allgather_op.h"

#include <memory>

#if defined(PADDLE_WITH_ASCEND_CL)

Review comment (Contributor): No need.

Reply (Contributor Author): Good comment!

#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {

Review comment (Contributor): Suggest renaming CAllGatherOpASCENDKernel to AllGatherOpNPUKernel, dropping the "C" prefix, since "CAll" reads like the word "Call".

Reply (Contributor Author): We'd better keep this style, because the other ops under this path start with "C"; it probably indicates that these ops are implemented in C/C++.

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    hcclDataType_t dtype = platform::ToHCCLDataType(in->type());

    int ring_id = ctx.Attr<int>("ring_id");
    std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
    std::string tag = ctx.Attr<std::string>("tag");
    auto place = ctx.GetPlace();
    auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);

Review comment (Contributor): Suggest adding HCCLCommContext to NPUDeviceContext in the future.

Reply (Contributor Author): Nice suggestion!

    int nranks = comm->nranks();

    framework::DDim out_dims = in->dims();
    out_dims[0] *= nranks;
    out->mutable_data<T>(out_dims, place);

    int64_t send_numel = in->numel();
    void *send_buff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));

Review comment (Contributor): Is the const_cast necessary?

Reply (Contributor Author): Just delete it.

Reply (Contributor Author): We need it in a few places, but not most of the time; I have removed the unneeded casts. (See the standalone sketch after this file for why a C-style interface can still force the cast.)

    void *recv_buff = reinterpret_cast<void*>(out->data<T>());

    aclrtStream stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
    } else {
      stream = comm->stream();
    }

    VLOG(3) << "begin hccl allgather, parameter is: "
            << ", group is " << group
            << ", ring_id is " << ring_id
            << ", nranks is " << nranks
            << ", tag is " << tag;

    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_gather(
        tag.c_str(), send_buff, recv_buff, (u64)send_numel, dtype,

Review comment (Contributor): Is u64 the same as uint64_t?

Reply (Contributor Author): It is defined by Huawei as "typedef unsigned long long u64;" in paddle/fluid/platform/dynload/hcom_type.h.

        group.c_str(), (void*)stream));

#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with NPU."));
#endif
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(c_allgather,
                       ops::CAllGatherOpASCENDKernel<int8_t>,
                       ops::CAllGatherOpASCENDKernel<int>,
                       ops::CAllGatherOpASCENDKernel<float>,
                       ops::CAllGatherOpASCENDKernel<plat::float16>);
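
To make the const_cast exchange above concrete, the following is a minimal standalone sketch, not part of the PR; hcom_all_gather_like is a hypothetical stand-in for the C-style HCCL call. It shows why a send buffer taken from a const tensor may still need a const_cast before it can pass through a plain void* parameter, while the writable receive buffer needs only a reinterpret_cast.

#include <cstdint>
#include <cstring>

// Hypothetical stand-in (not the real HCCL API) for a C-style collective call
// like the hcom_all_gather invocation above; it simply copies the send buffer
// so the sketch is self-contained and runnable.
int hcom_all_gather_like(const char* /*tag*/, void* send_buf, void* recv_buf,
                         uint64_t bytes) {
  std::memcpy(recv_buf, send_buf, bytes);
  return 0;
}

template <typename T>
int CallAllGather(const T* send, T* recv, uint64_t count) {
  // `send` points at const data, but the C-style interface takes plain void*,
  // which is exactly the situation that forces a const_cast in the kernel.
  void* send_buff = reinterpret_cast<void*>(const_cast<T*>(send));
  void* recv_buff = reinterpret_cast<void*>(recv);
  return hcom_all_gather_like("tag", send_buff, recv_buff, count * sizeof(T));
}

int main() {
  const float in[4] = {1.f, 2.f, 3.f, 4.f};
  float out[4] = {0.f, 0.f, 0.f, 0.f};
  return (CallAllGather(in, out, 4) == 0 && out[3] == 4.f) ? 0 : 1;
}
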
149 changes: 149 additions & 0 deletions paddle/fluid/operators/collective/c_allgather_op_npu_test.cc
@@ -0,0 +1,149 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef _WIN32
#include <unistd.h>
#endif

#include <string>
#include <thread> // NOLINT
#include <vector>
#include <stdio.h>

#include "gtest/gtest.h"

#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;

USE_OP(c_allgather);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_OP_DEVICE_KERNEL(c_allgather, NPU);

DECLARE_string(selected_npus);

template <typename T>
void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
  std::string debugstring = "";
  for (auto ele : data) {
    debugstring += std::to_string(ele) + std::string(",");
  }
  VLOG(2) << preStr << ":" << std::endl << debugstring;
}

void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
  int rank_id = atoi(getenv("RANK_ID"));
  int device_id = atoi(getenv("DEVICE_ID"));

  VLOG(2) << "rank_id = " << rank_id
          << "; device_id = " << device_id
          << "; rank_id = " << rank_id
          << "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));

  std::vector<int> rank_ids{0, 1};
  f::AttributeMap comm_init_attrs;
  comm_init_attrs["ring_id"] = 0;
  comm_init_attrs["nranks"] = 2;
  comm_init_attrs["rank"] = rank_id;
  comm_init_attrs["device_id"] = device_id;
  comm_init_attrs["rank_ids"] = rank_ids;
  auto comm_init_op =
      f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
  auto place = ctx.GetPlace();
  comm_init_op->Run(*scope, place);
  ctx.Wait();
}

void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
  // init
  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  std::vector<float> init;
  int rank_id = atoi(getenv("RANK_ID"));

  int num1 = 1;
  int num2 = 4;

  for (int64_t i = 0; i < num1 * num2; ++i) {
    init.push_back(1.0 + rank_id);
  }
  PrintDebugInfo("input data", init);

  TensorFromVector(init, ctx, tensor_x);
  tensor_x->Resize({num1, num2});
  ctx.Wait();

  auto place = ctx.GetPlace();
  auto out = scope->Var("Out");
  auto tensor_out = out->GetMutable<f::LoDTensor>();
  tensor_out->Resize({num1, num2});
  tensor_out->mutable_data<float>(place);  // allocate
  ctx.Wait();

  // run
  f::AttributeMap attrs;
  attrs["tag"] = std::string("tagx");
  attrs["ring_id"] = 0;
  attrs["nranks"] = 2;

  auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
                                    {{"Out", {"Out"}}}, attrs);

  op->Run(*scope, place);
  ctx.Wait();

  std::vector<float> out_vec;
  TensorToVector(*tensor_out, ctx, &out_vec);
  ctx.Wait();

  PrintDebugInfo("output data", out_vec);

  EXPECT_EQ(out_vec.size(), init.size() * 2);
  for (uint32_t i = 0; i < out_vec.size() / 2; i++) {
    EXPECT_EQ(out_vec[i], 1.0);
  }
  for (uint32_t i = out_vec.size() / 2; i < out_vec.size(); i++) {
    EXPECT_EQ(out_vec[i], 2.0);
  }
}


TEST(c_allgather, NPU) {
  f::Scope scope;

  // Only one device is supported; if more than one is available, the first
  // selected NPU is used by default.
  p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));

  Prepare(&scope, ctx);
  TestHCCLAllGatherOp(&scope, ctx);
}
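
For reference, the values checked at the end of the test above follow directly from all_gather semantics: rank 0 contributes four elements equal to 1.0, rank 1 contributes four elements equal to 2.0, and every rank receives the concatenation in rank order. A host-side sketch of that expected result (not part of the PR, using only the numbers from the test):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int nranks = 2;
  const int numel_per_rank = 4;  // num1 * num2 in the test above

  // Each rank's input is filled with 1.0 + rank_id; all_gather concatenates
  // the per-rank buffers in rank order on every rank.
  std::vector<float> gathered;
  for (int rank = 0; rank < nranks; ++rank) {
    std::vector<float> local(numel_per_rank, 1.0f + rank);
    gathered.insert(gathered.end(), local.begin(), local.end());
  }

  // Matches the EXPECT_EQ checks: first half 1.0, second half 2.0.
  assert(gathered.size() == static_cast<std::size_t>(nranks * numel_per_rank));
  for (int i = 0; i < numel_per_rank; ++i) assert(gathered[i] == 1.0f);
  for (int i = numel_per_rank; i < nranks * numel_per_rank; ++i) {
    assert(gathered[i] == 2.0f);
  }
  return 0;
}
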
4 changes: 2 additions & 2 deletions paddle/fluid/operators/collective/c_allreduce_max_op_npu.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.

@@ -25,7 +25,7 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;

 REGISTER_OP_NPU_KERNEL(c_allreduce_max,
-                       ops::CAllReduceOpASCENDKernel<ops::kRedMax, float>,
                        ops::CAllReduceOpASCENDKernel<ops::kRedMax, int>,
                        ops::CAllReduceOpASCENDKernel<ops::kRedMax, int8_t>,
+                       ops::CAllReduceOpASCENDKernel<ops::kRedMax, float>,
                        ops::CAllReduceOpASCENDKernel<ops::kRedMax, plat::float16>)
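
The c_allreduce_max kernels registered above compute an element-wise maximum across ranks. A host-side sketch of that semantics for two ranks (not part of the PR; the input values are made up for illustration):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Per-rank inputs of equal shape.
  const std::vector<float> rank0 = {1.f, 3.f, 5.f, 7.f};
  const std::vector<float> rank1 = {2.f, 2.f, 6.f, 6.f};

  // After c_allreduce_max, every rank holds the element-wise maximum.
  std::vector<float> reduced(rank0.size());
  std::transform(rank0.begin(), rank0.end(), rank1.begin(), reduced.begin(),
                 [](float a, float b) { return std::max(a, b); });

  assert((reduced == std::vector<float>{2.f, 3.f, 6.f, 7.f}));
  return 0;
}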