
Ascendrc #7

Merged · 6 commits · Mar 4, 2021
4 changes: 3 additions & 1 deletion paddle/fluid/operators/collective/CMakeLists.txt
@@ -40,4 +40,6 @@ cc_test(c_broadcast_op_npu_test SRCS c_broadcast_op_npu_test.cc DEPS op_registry
cc_test(c_allreduce_sum_op_npu_test SRCS c_allreduce_sum_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_max_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(c_reducescatter_op_npu_test SRCS c_reducescatter_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(c_allgather_op_npu_test SRCS c_allgather_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_allgather_op c_reducescatter_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} memory memcpy ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(send_v2_op_npu_test SRCS send_v2_op_npu_test.cc DEPS op_registry send_v2_op recv_v2_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(recv_v2_op_npu_test SRCS recv_v2_op_npu_test.cc DEPS op_registry send_v2_op recv_v2_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
1 change: 0 additions & 1 deletion paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc
@@ -129,7 +129,6 @@ TEST(c_reducescatter, NPU) {
char* npu_id = getenv("FLAGS_selected_npus");

p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));

Prepare(&scope, ctx);
TestHCCLReduceScatterOp(&scope, ctx);
}
8 changes: 8 additions & 0 deletions paddle/fluid/operators/collective/recv_v2_op.cc
@@ -62,6 +62,14 @@ class RecvOpV2Maker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
AddAttr<int>("dtype", "(int default 5('float32')) data type of tensor.")
.SetDefault(5);
#if defined(PADDLE_WITH_ASCEND_CL)
#pragma message("tag")
AddAttr<std::string>("tag", "(string default tag) tag for broadcasting.")
.SetDefault("tag");
#pragma message("srTag")
AddAttr<int>("srTag", "(string default tag) tag for broadcasting.")
.SetDefault(0);
#endif
AddAttr<std::vector<int>>("out_shape", "shape of the output tensor.")
.SetDefault(std::vector<int>());
AddAttr<bool>(
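For reference, the "dtype" default of 5 is the integer value of the framework's VarType::FP32 enum, which is why the description reads "5('float32')". A minimal sketch of a caller avoiding the magic number (assuming the `f` namespace alias used in the tests below; MakeRecvAttrs is illustrative, not part of the PR):

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"

namespace f = paddle::framework;

// Build recv_v2 attributes with an explicit dtype instead of the literal 5.
f::AttributeMap MakeRecvAttrs(int src_rank) {
  f::AttributeMap attrs;
  attrs["dtype"] = static_cast<int>(f::proto::VarType::FP32);  // == 5
  attrs["peer"] = src_rank;  // rank id of the sender, as registered above
  return attrs;
}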
72 changes: 72 additions & 0 deletions paddle/fluid/operators/collective/recv_v2_op_npu.cc
@@ -0,0 +1,72 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/recv_v2_op.h"

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class CRecvOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto out = ctx.Output<framework::LoDTensor>("Out");
int numel = out->numel();
hcclDataType_t dtype = platform::ToHCCLDataType(out->type());

int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);

aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
std::string tag = ctx.Attr<std::string>("tag");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int srcRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
VLOG(3) << "recv_v2_npu attr get";
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_receive(
tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(out->data<T>())), numel, dtype, srcRank,
srTag, group.c_str(), stream));
VLOG(3) << "Source Rank: " << srcRank << " Invoke hcom receive. receiving ";
out->Resize(out->dims());
out->set_lod(out->lod());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with Ascend."));
#endif
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(recv_v2, ops::CRecvOpASCENDKernel<float>,
ops::CRecvOpASCENDKernel<int>,
ops::CRecvOpASCENDKernel<int8_t>,
ops::CRecvOpASCENDKernel<plat::float16>);
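Both new kernels repeat the use_calc_stream branch verbatim. A possible shared helper, sketched under the assumption that the communicator type returned by HCCLCommContext::Instance().Get() is platform::HCCLComm (mirroring the NCCL helper; the type itself is not shown in this diff):

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

// Pick the op's calculation stream or the communicator's dedicated stream.
inline aclrtStream SelectHcclStream(const framework::ExecutionContext& ctx,
                                    platform::HCCLComm* comm) {
  if (ctx.Attr<bool>("use_calc_stream")) {
    // Reuse the computation stream, so no extra synchronization is needed.
    auto dev_ctx = platform::DeviceContextPool::Instance().Get(ctx.GetPlace());
    return static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
  }
  return comm->stream();
}

}  // namespace operators
}  // namespace paddle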
125 changes: 125 additions & 0 deletions paddle/fluid/operators/collective/recv_v2_op_npu_test.cc
@@ -0,0 +1,125 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef _WIN32
#include <unistd.h>
#endif

#include <iostream>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"

#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/fluid/operators/collective/recv_v2_op.h"

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;

USE_OP(recv_v2);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_OP_DEVICE_KERNEL(recv_v2, NPU);

void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
std::string rank_table_file = getenv("RANK_TABLE_FILE");
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
int src_rank = atoi(getenv("SRC_RANK"));
int dest_rank = atoi(getenv("DEST_RANK"));
VLOG(3) << "rank_id " << rank_id << " src_rank " << src_rank << " dest_rank " << dest_rank;

std::vector<int> rank_ids = {0,1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
VLOG(3) << "CreateOp c_comm_init_hcom";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
}

void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx) {
std::cout << "BEGIN TEST:" << __FUNCTION__ << std::endl;

int num = atoi(getenv("DATA_SIZE"));
EXPECT_GT(num, 0);
EXPECT_LT(num, 1 << 15);
int rank_id = atoi(getenv("RANK_ID"));
VLOG(3) << "rank_id:" << rank_id<<std::endl;
std::cout<<std::endl;

ctx.Wait();
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num, num});
tensor_out->mutable_data<float>(place); // allocate

ctx.Wait();

f::AttributeMap attrs;
attrs["tag"] = std::string("srtest");
attrs["peer"] = atoi(getenv("SRC_RANK"));
attrs["ring_id"] = 0;
attrs["srTag"] = 0;
std::vector<int> out_shape;
out_shape.push_back(num);
out_shape.push_back(num);
attrs["out_shape"] = out_shape;

auto op =
f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs);
VLOG(3) << "CreateOp recv_v2";

op->Run(*scope, place);
VLOG(3) << "Run op recv_v2";
std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait();
std::vector<float> init(num * num, 1.0 * atoi(getenv("DEST_RANK")));
EXPECT_EQ(out_vec, init);
}


TEST(recv_v2, NPU) {
f::Scope scope;
char* npu_id = getenv("FLAGS_selected_npus");
VLOG(3) << "Select npu:" << npu_id;
p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
VLOG(3) << "Place over";
Prepare(&scope, ctx);
VLOG(3) << "Prepare over";
TestHcomRecvOp(&scope, ctx);
VLOG(3) << "Test over";
}
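A small defensive note on the test setup: atoi(getenv(...)) dereferences a null pointer when a variable such as RANK_ID is unset. A guarded reader is a one-liner (sketch only; GetEnvInt is a hypothetical helper, not part of this PR):

#include <cstdlib>

// Return the integer value of an environment variable, or a default if unset.
static int GetEnvInt(const char* name, int default_value) {
  const char* value = std::getenv(name);
  return value ? std::atoi(value) : default_value;
}

// Usage: int rank_id = GetEnvInt("RANK_ID", 0);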
8 changes: 8 additions & 0 deletions paddle/fluid/operators/collective/send_v2_op.cc
@@ -50,6 +50,14 @@ class SendOpV2Maker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("peer", "(int default 0) rank id for receiver.").SetDefault(0);
#if defined(PADDLE_WITH_ASCEND_CL)
#pragma message("tag")
AddAttr<std::string>("tag", "(string default tag) tag for broadcasting.")
.SetDefault("tag");
#pragma message("srTag")
AddAttr<int>("srTag", "(string default tag) tag for broadcasting.")
.SetDefault(0);
#endif
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
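Because "tag" and "srTag" are only registered under PADDLE_WITH_ASCEND_CL, code shared across builds has to guard the assignment the same way. A sketch (MakeSendAttrs is hypothetical; the tests suggest the tag values must match on both ranks for hcom to pair the calls):

#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace f = paddle::framework;

f::AttributeMap MakeSendAttrs(int dest_rank) {
  f::AttributeMap attrs;
  attrs["peer"] = dest_rank;  // rank id of the receiver
#if defined(PADDLE_WITH_ASCEND_CL)
  attrs["tag"] = std::string("srtest");  // same literal the tests use
  attrs["srTag"] = 0;                    // matches the receiver's srTag
#endif
  return attrs;
}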
73 changes: 73 additions & 0 deletions paddle/fluid/operators/collective/send_v2_op_npu.cc
@@ -0,0 +1,73 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/send_v2_op.h"

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class CSendOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto x = ctx.Input<framework::LoDTensor>("X");
int numel = x->numel();
hcclDataType_t dtype = platform::ToHCCLDataType(x->type());

auto place = ctx.GetPlace();
int ring_id = ctx.Attr<int>("ring_id");
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);

aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
std::string tag = ctx.Attr<std::string>("tag");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int destRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");

PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_send(
tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype, destRank,
srTag, group.c_str(), stream));

VLOG(3) << "Dest rank:" << destRank << " Invoke hcom send. Sent "
<< x->numel();

#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with Ascend."));
#endif
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(send_v2, ops::CSendOpASCENDKernel<float>,
ops::CSendOpASCENDKernel<int>,
ops::CSendOpASCENDKernel<int8_t>,
ops::CSendOpASCENDKernel<plat::float16>);
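Putting the two ops together: a rank-conditional driver in the style of the two new tests might look like the sketch below (RunPair is hypothetical; it assumes the same RANK_ID environment variable and f/p aliases the tests use, and omits out_shape/dtype on the receive side for brevity):

#include <cstdlib>
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"

namespace f = paddle::framework;
namespace p = paddle::platform;

// Rank 0 sends, rank 1 receives; tag, srTag and ring_id are shared so the
// matching hcom_send/hcom_receive calls can pair up.
void RunPair(f::Scope* scope, const p::DeviceContext& ctx) {
  int rank_id = atoi(getenv("RANK_ID"));
  f::AttributeMap attrs;
  attrs["tag"] = std::string("srtest");
  attrs["ring_id"] = 0;
  attrs["srTag"] = 0;
  if (rank_id == 0) {
    attrs["peer"] = 1;  // destination rank
    auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs);
    op->Run(*scope, ctx.GetPlace());
  } else {
    attrs["peer"] = 0;  // source rank
    auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs);
    op->Run(*scope, ctx.GetPlace());
  }
  ctx.Wait();
}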