Skip to content

Commit

Permalink
[cherry-pick] fix gather bug && fix hang of new_group (#33553)
Browse files Browse the repository at this point in the history
* Fix gather infer shape using axis (#33413)

* fix gather shape bug

* fix None

* fix topo

* Fix hang of hybrid parallel in new_group  (#33141)

* fix hang of hybrid parallel

* fix new_group for hang problem

* fix hang
  • Loading branch information
ForFishes authored Jun 15, 2021
1 parent 036f81f commit a4e841e
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 213 deletions.
26 changes: 7 additions & 19 deletions paddle/fluid/operators/gather.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,25 +202,20 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
}
}

template <typename T, typename U, typename V>
template <typename T, typename U>
void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const int axis, Tensor* out,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
int axis_size = axis->numel();
int index_size = index->numel();
int input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
auto* index_data = index->data<U>();

if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
Tensor cpu_axis;
framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
int axis_index = cpu_axis.data<V>()[0];

int axis_index = axis;
int index_dim_size = input_dim[axis_index];

int inner_dim_size = 1;
Expand Down Expand Up @@ -251,26 +246,19 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
index_size, index_dim_size, out_size);
}

template <typename T, typename U, typename V>
template <typename T, typename U>
void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const int axis, Tensor* out,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
auto* index_data = index->data<U>();

int axis_size = axis->numel();
int index_size = index->numel();
int input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();

if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
Tensor cpu_axis;
framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
int axis_index = cpu_axis.data<V>()[0];
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];

int inner_dim_size = 1;
Expand Down
26 changes: 7 additions & 19 deletions paddle/fluid/operators/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,17 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
}
}

template <typename T, typename U, typename V>
void GatherV2Function(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const paddle::platform::Place& place) {
auto* axis_data = axis->data<V>();
template <typename T, typename U>
void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
Tensor* out, const paddle::platform::Place& place) {
auto* index_data = index->data<U>();

int axis_size = axis->numel();
int index_size = index->numel();
int input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();

if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
int axis_index = axis_data[0];
int axis_index = axis;

int input_index_dim_size = input_dim[axis_index];
for (int i = 0; i < index_size; i++) {
Expand Down Expand Up @@ -186,22 +179,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index,
}
}

template <typename T, typename U, typename V>
template <typename T, typename U>
void GatherV2GradFunction(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const int axis, Tensor* out,
const paddle::platform::Place& place) {
auto* axis_data = axis->data<V>();
auto* index_data = index->data<U>();

int axis_size = axis->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();

if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
int axis_index = axis_data[0];
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];

int inner_dim_size = 1;
Expand Down
33 changes: 28 additions & 5 deletions paddle/fluid/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -52,11 +53,29 @@ class GatherOp : public framework::OperatorWithKernel {
index_dims.size()));
}

int batch_size = ctx->GetInputDim("Index")[0];
framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
auto axis = ctx->Attrs().Get<int>("axis");
auto input_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Axis") || axis == 0) {
// if HasInput("Axis"), we can not obtain correct shape of output
int batch_size = index_dims[0];
framework::DDim output_dims(input_dim);
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = framework::make_ddim(out_dim_vec);
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}

protected:
Expand Down Expand Up @@ -120,6 +139,10 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
"If true, update the grad using the overwrite mode in same index,"
"If false, using the accumulate mode in same index.")
.SetDefault(true);
AddAttr<int>(
"axis",
"The Tensor which contains the axis that we do gather operation.")
.SetDefault(0);
AddComment(R"DOC(
Gather Operator.
Expand Down
108 changes: 39 additions & 69 deletions paddle/fluid/operators/gather_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,47 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");

int axis = ctx.Attr<int>("axis");

// get axis from tensor
if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int32_t, int32_t>(x, index, axis, output, place,
ctx);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int32_t, int64_t>(x, index, axis, output, place,
ctx);
Tensor cpu_axis;
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type = axis_tensor->type();
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int64_t, int32_t>(x, index, axis, output, place,
ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int64_t, int64_t>(x, index, axis, output, place,
ctx);
}
const auto &place = ctx.GetPlace();
const auto &index_type = index->type();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int32_t>(x, index, axis, output, place, ctx);
} else if (index_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int64_t>(x, index, axis, output, place, ctx);
}
return;
}

output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUGather<T, int>(ctx.device_context(), *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
Expand All @@ -91,30 +77,27 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));

int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int32_t, int32_t>(dO, index, axis, dX,
place, ctx);
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
Tensor cpu_axis;
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type = axis_tensor->type();
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int32_t, int64_t>(dO, index, axis, dX,
place, ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int64_t, int32_t>(dO, index, axis, dX,
place, ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int64_t, int64_t>(dO, index, axis, dX,
place, ctx);
}

const auto &index_type = index->type();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
ctx.GetPlace(), ctx);
} else if (index_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
ctx.GetPlace(), ctx);
}
return;
}
Expand All @@ -125,19 +108,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;

const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
Expand Down
Loading

0 comments on commit a4e841e

Please sign in to comment.