add all ccl op for Ascendrc #31437
Changes from 42 commits
@@ -30,10 +30,18 @@ if(WITH_XPU_BKCL)
endif()

if(WITH_ASCEND_CL)
  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
endif()

set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")

cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
if(WITH_ASCEND_CL)
  cc_test(c_broadcast_op_npu_test SRCS c_broadcast_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
  cc_test(c_allreduce_sum_op_npu_test SRCS c_allreduce_sum_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
  cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_max_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
  cc_test(c_reducescatter_op_npu_test SRCS c_reducescatter_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
  cc_test(c_allgather_op_npu_test SRCS c_allgather_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
  cc_test(send_v2_op_npu_test SRCS send_v2_op_npu_test.cc DEPS op_registry send_v2_op recv_v2_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
  cc_test(recv_v2_op_npu_test SRCS recv_v2_op_npu_test.cc DEPS op_registry send_v2_op recv_v2_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
endif()
Review comment: Maybe add a newline after the last line and the symbol will be dismissed.
Reply: Good! Simple code can be more friendly.
@@ -0,0 +1,86 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allgather_op.h"

#include <memory>

#if defined(PADDLE_WITH_ASCEND_CL)
Review comment: No need.
Reply: Good comment!
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
Review comment: Suggest CAllGatherOpASCENDKernel -> AllGatherOpNPUKernel; remove the prefix "C", since it reads like the word "CAll".
Reply: We'd better keep this style, because the other ops under this path start with 'C'. Maybe it means these ops were implemented in C/C++.
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    hcclDataType_t dtype = platform::ToHCCLDataType(in->type());

    int ring_id = ctx.Attr<int>("ring_id");
    std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
    std::string tag = ctx.Attr<std::string>("tag");
    auto place = ctx.GetPlace();
    auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
Review comment: Suggest adding HCCLCommContext to NPUDeviceContext in the future.
Reply: Handsome!
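A standalone toy sketch of that suggestion, with made-up names (nothing here is existing Paddle API): the idea is that the per-device context would hand out the collective communicator, so kernels would not need to query the HCCLCommContext singleton directly.

    // Hypothetical sketch only: ToyComm / ToyNPUDeviceContext are invented
    // here to illustrate the suggestion; they are not Paddle types.
    #include <map>
    #include <memory>

    struct ToyComm {
      int nranks() const { return 2; }
    };

    class ToyNPUDeviceContext {
     public:
      // A kernel could then write: dev_ctx.comm(ring_id)->nranks();
      // instead of going through a global comm-context singleton.
      ToyComm* comm(int ring_id) {
        auto& slot = comms_[ring_id];
        if (!slot) slot.reset(new ToyComm);
        return slot.get();
      }

     private:
      std::map<int, std::unique_ptr<ToyComm>> comms_;
    };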
    int nranks = comm->nranks();

    framework::DDim out_dims = in->dims();
    out_dims[0] *= nranks;
    out->mutable_data<T>(out_dims, place);

    int64_t send_numel = in->numel();
    void* send_buff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
Review comment: Is this needed?
Review comment: Just delete it.
Reply: We need it in some places, but not most of the time; I have deleted most of the unneeded cases.
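For illustration, a standalone snippet (not part of this diff) showing why the outer reinterpret_cast adds nothing: once const is cast away, a T* converts to void* implicitly.

    #include <vector>

    int main() {
      std::vector<float> in(4, 1.0f);
      const float* data = in.data();
      // const_cast removes constness; the resulting float* converts to void*
      // implicitly, so wrapping it in reinterpret_cast<void*>(...) is redundant.
      void* send_buff = const_cast<float*>(data);
      return send_buff == in.data() ? 0 : 1;
    }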
    void* recv_buff = reinterpret_cast<void*>(const_cast<T*>(out->data<T>()));

    aclrtStream stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
    } else {
      stream = comm->stream();
    }

    VLOG(3) << "begin hccl allgather, parameter is: "
            << ", group is " << group
            << ", ring_id is " << ring_id
            << ", nranks is " << nranks
            << ", tag is " << tag;

    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_gather(
        tag.c_str(), send_buff, recv_buff, (u64)send_numel, dtype,
Review comment: Is u64 uint64_t?
Reply: typedef unsigned long long u64;
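A standalone check of that answer (illustrative; the C++ standard only guarantees unsigned long long is at least 64 bits, but the sizes match on the platforms this code targets):

    #include <cstdint>

    typedef unsigned long long u64;  // the typedef quoted in the reply above
    static_assert(sizeof(u64) == sizeof(std::uint64_t),
                  "u64 is expected to be 64 bits wide, matching uint64_t");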
        group.c_str(), (void*)stream));

#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with NPU."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(c_allgather,
                       ops::CAllGatherOpASCENDKernel<int8_t>,
                       ops::CAllGatherOpASCENDKernel<int>,
                       ops::CAllGatherOpASCENDKernel<float>,
                       ops::CAllGatherOpASCENDKernel<plat::float16>);
@@ -0,0 +1,147 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef _WIN32
#include <unistd.h>
#endif

#include <string>
#include <thread>  // NOLINT
#include <vector>
#include <stdio.h>

#include "gtest/gtest.h"

#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;

USE_OP(c_allgather);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_OP_DEVICE_KERNEL(c_allgather, NPU);

template <typename T>
void PrintDebugInfo(std::string preStr, std::vector<T>& data) {
  std::string debugstring = "";
  for (auto ele : data) {
    debugstring += std::to_string(ele) + std::string(",");
  }
  VLOG(2) << preStr << ":" << std::endl << debugstring;
}

void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
  int rank_id = atoi(getenv("RANK_ID"));
  int device_id = atoi(getenv("DEVICE_ID"));

  VLOG(2) << "rank_id = " << rank_id
          << "; device_id = " << device_id
          << "; rank_id = " << rank_id
          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));

  std::vector<int> rank_ids{0, 1};
  f::AttributeMap comm_init_attrs;
  comm_init_attrs["ring_id"] = 0;
  comm_init_attrs["nranks"] = 2;
  comm_init_attrs["rank"] = rank_id;
  comm_init_attrs["device_id"] = device_id;
  comm_init_attrs["rank_ids"] = rank_ids;
  auto comm_init_op =
      f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
  auto place = ctx.GetPlace();
  comm_init_op->Run(*scope, place);
  ctx.Wait();
}

void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
  // init
  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  std::vector<float> init;
  int rank_id = atoi(getenv("RANK_ID"));

  int num1 = 1;
  int num2 = 4;

  for (int64_t i = 0; i < num1 * num2; ++i) {
    init.push_back(1.0 + rank_id);
  }
  PrintDebugInfo("input data", init);

  TensorFromVector(init, ctx, tensor_x);
  tensor_x->Resize({num1, num2});
  ctx.Wait();

  auto place = ctx.GetPlace();
  auto out = scope->Var("Out");
  auto tensor_out = out->GetMutable<f::LoDTensor>();
  tensor_out->Resize({num1, num2});
  tensor_out->mutable_data<float>(place);  // allocate
  ctx.Wait();

  // run
  f::AttributeMap attrs;
  attrs["tag"] = std::string("tagx");
  attrs["ring_id"] = 0;
  attrs["nranks"] = 2;

  auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
                                    {{"Out", {"Out"}}}, attrs);

  op->Run(*scope, place);
  ctx.Wait();

  std::vector<float> out_vec;
  TensorToVector(*tensor_out, ctx, &out_vec);
  ctx.Wait();

  PrintDebugInfo("output data", out_vec);

  EXPECT_EQ(out_vec.size(), init.size() * 2);
  for (uint32_t i = 0; i < out_vec.size() / 2; i++) {
    EXPECT_EQ(out_vec[i], 1.0);
  }
  for (uint32_t i = out_vec.size() / 2; i < out_vec.size(); i++) {
    EXPECT_EQ(out_vec[i], 2.0);
  }
}

TEST(c_allgather, NPU) {
  f::Scope scope;
  char* npu_id = getenv("FLAGS_selected_npus");

  p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));

  Prepare(&scope, ctx);
  TestHCCLAllGatherOp(&scope, ctx);
}
Review comment: It seems this is not necessary.
Reply: Good comment!
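As a standalone restatement of what the test's assertions check (two ranks; rank 0 contributes 1.0s and rank 1 contributes 2.0s; allgather concatenates the 1x4 inputs along dim 0):

    #include <cassert>
    #include <vector>

    int main() {
      // Expected gather result on every rank: {1,1,1,1, 2,2,2,2}.
      std::vector<float> expected;
      for (int rank = 0; rank < 2; ++rank) {
        for (int i = 0; i < 1 * 4; ++i) expected.push_back(1.0f + rank);
      }
      assert(expected.size() == 8u);
      assert(expected.front() == 1.0f && expected.back() == 2.0f);
      return 0;
    }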