Fix allreduce_sum potential bugs on NPU. #34462

Merged (6 commits) on Jul 29, 2021
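In brief: section_worker.cc gains VLOG(2) trace points for the startup, 1F1B, backward, and update phases of the pipeline schedule. c_allreduce_op.h hardens the collective kernels: the NPU (Ascend) path now probes its input with a CheckNumerics helper before HcclAllReduce and overwrites the input with Inf when NaN/Inf is found, and the XPU and CUDA kernels switch sendbuff to the typed data<T>() accessor.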
5 changes: 5 additions & 0 deletions paddle/fluid/framework/section_worker.cc
@@ -164,6 +164,7 @@ void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
   while (fw_step < startup_steps) {
     RunForward(fw_step, gc, unused_vars_);
     fw_step += 1;
+    VLOG(2) << "micro steps fw_step:" << fw_step;
   }

   // 1f1b phase
@@ -180,15 +181,18 @@ void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {

     fw_step += 1;
     bw_step += 1;
+    VLOG(2) << "micro steps fw_step:" << fw_step << ", bw_step:" << bw_step;
   }

   int reserve_bw_send_step = bw_step - 2;
   // backward phase
   while (bw_step < num_microbatches_) {
     RunBackward(bw_step, gc, unused_vars_);
     bw_step += 1;
+    VLOG(2) << "micro steps bw_step:" << bw_step;
   }

+  VLOG(2) << "run update";
   RunUpdate(gc, unused_vars_);

   if (gc) {
@@ -203,6 +207,7 @@ void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {

 void SectionWorker::TrainFiles() {
   VLOG(5) << "begin section_worker TrainFiles";
+  VLOG(2) << "mini batch steps:" << batch_id_;

   int64_t max_memory_size = GetEagerDeletionThreshold();
   std::unique_ptr<GarbageCollector> gc;
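Note: these VLOG(2) trace points (and the VLOG(3)/VLOG(4) ones in the next file) are compiled in unconditionally but printed only when glog verbosity is raised, typically by exporting GLOG_v=2 before launching training.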
87 changes: 79 additions & 8 deletions paddle/fluid/operators/collective/c_allreduce_op.h
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/npu_op_runner.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU_BKCL)
@@ -119,13 +120,45 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
   }
 };

+#if defined(PADDLE_WITH_ASCEND_CL)
+// Returns true if Inf or NaN is found in the input tensor, false otherwise.
+template <typename T>
+bool CheckNumerics(const framework::ExecutionContext& exe_ctx,
+                   aclrtStream stream, const paddle::framework::Tensor* in) {
+  auto& dev_ctx =
+      exe_ctx.template device_context<paddle::platform::NPUDeviceContext>();
+  using Tensor = paddle::framework::Tensor;
+  Tensor out(in->type());
+  out.Resize(in->dims());
+  out.mutable_data<T>(dev_ctx.GetPlace());
+
+  bool found_inf_data = false;
+
+  try {
+    const auto& runner =
+        NpuOpRunner("CheckNumerics", {*in}, {out},
+                    {{"message", std::string("check_numerics")}});
+    runner.Run(stream);
+    dev_ctx.Wait();
+  } catch (platform::EnforceNotMet& exception) {
+    LOG(WARNING) << "[check_nan_and_inf] detected NaN or INF!";
+    found_inf_data = true;
+  } catch (...) {
+    LOG(WARNING) << "[check_nan_and_inf] detected NaN or INF!";
+    found_inf_data = true;
+  }
+
+  return found_inf_data;
+}
+#endif
+
 template <ReduceType red_type, typename T>
 class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_ASCEND_CL)
-    auto in = ctx.Input<framework::LoDTensor>("X");
-    auto out = ctx.Output<framework::LoDTensor>("Out");
+    auto in = ctx.Input<framework::Tensor>("X");
+    auto out = ctx.Output<framework::Tensor>("Out");
     auto place = ctx.GetPlace();
     HcclDataType dtype = platform::ToHCCLDataType(in->type());
     int64_t numel = in->numel();
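The CheckNumerics helper added above converts a failing ACL CheckNumerics launch into a boolean instead of a crash. A minimal caller sketch follows (hypothetical, not part of this diff; it assumes an NPU build with an ExecutionContext ctx, a typed NPUDeviceContext* dev_ctx, and a framework::Tensor* grad in scope):

    #if defined(PADDLE_WITH_ASCEND_CL)
    // Probe the tensor before a collective launch. Note the helper calls
    // dev_ctx.Wait(), so this is a synchronizing, debug-cost check.
    if (CheckNumerics<float>(ctx, dev_ctx->stream(), grad)) {
      LOG(WARNING) << "tensor holds NaN/Inf; sanitize before the collective";
    }
    #endif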
@@ -141,9 +174,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
         paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);

     aclrtStream stream = nullptr;
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+    auto dev_ctx = static_cast<platform::NPUDeviceContext*>(
+        platform::DeviceContextPool::Instance().Get(place));
     if (ctx.Attr<bool>("use_calc_stream")) {
-      stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
+      stream = dev_ctx->stream();
     } else {
       stream = comm->stream();
     }
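Caching dev_ctx as a typed NPUDeviceContext* up front, rather than casting at each use, lets the same pointer serve both the use_calc_stream branch here and the dev_ctx->stream() argument handed to CheckNumerics later in this kernel.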
@@ -171,9 +205,46 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type));
}

VLOG(3) << "begin hccl allreduce, parameter is: "
VLOG(3) << "hccl allreduce, parameter is: "
<< "input num: " << in->dims() << "dtype: " << dtype
<< "hccl_red_type: " << hccl_red_type << ", group is: " << group
<< ", sendbuff:" << sendbuff << ", recvbuff:" << recvbuff
<< ", out_size:" << out->memory_size()
<< ", use_calc_stream:" << ctx.Attr<bool>("use_calc_stream")
<< ", stream:" << stream;

framework::Tensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace());

bool check_numerics = false;

auto d_type = in->type();
switch (d_type) {
case framework::proto::VarType::FP16:
case framework::proto::VarType::FP32: {
VLOG(4) << "prepare to FoundNanInf";
check_numerics = CheckNumerics<T>(ctx, dev_ctx->stream(), in);
VLOG(4) << "check_numerics:" << check_numerics;
break;
}
default:
break;
}

if (check_numerics) {
T inf = static_cast<T>(std::numeric_limits<float>::infinity());
VLOG(4) << "fill input data constant inf";
auto dims = in->dims();
auto mutable_in = const_cast<framework::Tensor*>(in);
FillNpuTensorWithConstant<T>(mutable_in, inf);
mutable_in->Resize(dims);
}

VLOG(3) << "hccl allreduce, parameter is: "
<< "input num: " << numel << "dtype: " << dtype
<< "hccl_red_type: " << hccl_red_type << ", group is: " << group;
<< "hccl_red_type: " << hccl_red_type << ", group is: " << group
<< ", sendbuff:" << sendbuff << ", recvbuff:" << recvbuff
<< ", out_size:" << out->memory_size();

PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce(
sendbuff, recvbuff, numel, dtype, hccl_red_type, comm->comm(),
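The Inf fill appears to be the core of the fix: Inf survives a sum-allreduce, so when one rank's input holds NaN/Inf, every rank observes the overflow in the reduced result and can react in lockstep, for example AMP loss scaling skipping the step, instead of exchanging undefined values. (The small tmp tensor allocated before the switch is not referenced again in this hunk.) A sketch of the presumed downstream effect, an assumption about intended use rather than code from this PR:

    // Hypothetical post-allreduce probe: because Inf propagates through the
    // sum, every rank computes the same found_inf and can skip the update.
    bool found_inf = CheckNumerics<T>(ctx, dev_ctx->stream(), out);
    if (found_inf) {
      VLOG(4) << "all ranks skip this optimizer step";
    }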
@@ -198,7 +269,7 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     BKCLDataType dtype = platform::ToBKCLDataType(in->type());
     int64_t numel = in->numel();
-    const void* sendbuff = in->data<void>();
+    const void* sendbuff = in->data<T>();
     out->Resize(in->dims());
     void* recvbuff = out->mutable_data<T>(place);

@@ -260,7 +331,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
     int64_t numel = in->numel();
-    const void* sendbuff = in->data<void>();
+    const void* sendbuff = in->data<T>();
     out->Resize(in->dims());
     void* recvbuff = out->mutable_data<T>(place);

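The same one-line change lands in both the XPU (BKCL) and CUDA (NCCL) kernels: data<T>() is the dtype-checked accessor whose const T* result converts implicitly to the const void* the collective API expects, while the old data<void>() instantiation has no dtype to check against and is presumably one of the potential bugs the PR title refers to.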