Skip to content

Commit

Permalink
Merge pull request #31 from jiaoxuewu/paddlebox
Browse files Browse the repository at this point in the history
index_select
  • Loading branch information
qingshui authored Jan 6, 2022
2 parents 075b552 + 495be05 commit f39b582
Show file tree
Hide file tree
Showing 10 changed files with 583 additions and 262 deletions.
85 changes: 85 additions & 0 deletions paddle/fluid/operators/gather.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,5 +297,90 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
input_data, index_data, out_data, outer_dim_size, inner_dim_size,
input_index_dim_size, out_index_dim_size, input_size);
}

// Gathers slices of `input` along `axis` at the positions given by `index`
// and writes them to `out` on the GPU.
//
// The input is viewed as a 3-D layout
//   [inner_dim_size, index_dim_size, outer_dim_size]
// where inner_dim_size is the product of the dims *before* `axis` and
// outer_dim_size the product of the dims *after* it (note the naming is
// inverted w.r.t. the usual inner/outer convention). The output shape
// replaces the axis extent with index->numel().
//
// NOTE(review): unlike the CPU GatherV2Function, index values are not range
// checked here; an out-of-range index causes undefined GPU reads — confirm
// callers validate indices first.
template <typename T, typename U>
void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
                          const int axis, Tensor* out,
                          const paddle::platform::Place& place,
                          const framework::ExecutionContext& ctx) {
  int64_t index_size = index->numel();
  int64_t input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();
  auto* index_data = index->data<U>();

  // Empty input: nothing to gather. (Use the cached size instead of
  // re-calling numel(); previously `input_size` was a dead local.)
  if (input_size == 0) return;

  int axis_index = axis;
  // Extent of the gather axis in the input.
  int64_t index_dim_size = input_dim[axis_index];

  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;
  std::vector<int64_t> out_dim_vec;

  // Dims before the axis contribute to inner_dim_size and are kept in the
  // output shape.
  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  // The axis extent is replaced by the number of gathered indices.
  out_dim_vec.push_back(index_size);
  // Dims after the axis contribute to outer_dim_size and are kept.
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  auto out_dim = framework::make_ddim(out_dim_vec);

  out->Resize(out_dim);
  auto* out_data = out->mutable_data<T>(place);
  int64_t out_size = out->numel();
  // Empty index produces an empty output; skip the kernel launch entirely.
  if (out_size == 0) return;

  // One thread per output element.
  platform::GpuLaunchConfig config =
      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size);
  auto stream = ctx.cuda_device_context().stream();
  GatherGPUKernel<
      T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
      input_data, index_data, out_data, outer_dim_size, inner_dim_size,
      index_size, index_dim_size, out_size);
}

// Gradient of GatherV2 on the GPU: scatter-adds slices of `input` (the
// upstream gradient, laid out like the forward output) back into `out`
// (the gradient w.r.t. the forward input) at positions `index` along
// `axis`.
//
// `out` is zero-filled first because several indices may map to the same
// slice and their contributions must accumulate (presumably atomically
// inside GatherGradGPUKernel — kernel body not shown here).
template <typename T, typename U>
void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
                              const int axis, Tensor* out,
                              const paddle::platform::Place& place,
                              const framework::ExecutionContext& ctx) {
  auto* index_data = index->data<U>();
  int64_t input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  // Empty upstream gradient: nothing to scatter back. (Use the cached
  // size; the dead local `index_size` has been removed.)
  if (input_size == 0) return;
  int axis_index = axis;
  // Extent of the gather axis in the upstream gradient (== number of
  // indices gathered by the forward op).
  int64_t input_index_dim_size = input_dim[axis_index];

  // Products of the dims before (inner) / after (outer) the axis; note the
  // naming is inverted w.r.t. the usual inner/outer convention.
  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
  }

  auto* out_data = out->mutable_data<T>(place);
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  auto out_dim = out->dims();
  int64_t out_index_dim_size = out_dim[axis_index];
  // Zero the gradient buffer so the kernel can accumulate into it.
  operators::math::set_constant(*dev_ctx, out, 0.0);

  // One thread per upstream-gradient element.
  platform::GpuLaunchConfig config =
      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_size);
  auto stream = ctx.cuda_device_context().stream();
  GatherGradGPUKernel<
      T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
      input_data, index_data, out_data, outer_dim_size, inner_dim_size,
      input_index_dim_size, out_index_dim_size, input_size);
}
} // namespace operators
} // namespace paddle
99 changes: 99 additions & 0 deletions paddle/fluid/operators/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,104 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index,
}
}

// CPU gather: copies slices of `input` along `axis` at the positions given
// by `index` into `out`.
//
// The input is treated as a 3-D view [prefix, axis_extent, suffix], where
// prefix is the product of the dims before `axis` and suffix the product of
// the dims after it; each gathered slice is `suffix` contiguous elements.
// Every index is validated to lie in [0, axis_extent) before any copying.
template <typename T, typename U>
void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
                      Tensor* out, const paddle::platform::Place& place) {
  const auto* idx_ptr = index->data<U>();
  const int64_t num_indices = index->numel();
  const int64_t total_src = input->numel();
  auto src_dims = input->dims();
  const auto* src_ptr = input->data<T>();

  if (input->numel() == 0) return;
  const int axis_index = axis;

  // Reject any index outside the valid range of the gather axis up front.
  const int64_t input_index_dim_size = src_dims[axis_index];
  for (int64_t i = 0; i < num_indices; i++) {
    PADDLE_ENFORCE_LT(idx_ptr[i], input_index_dim_size,
                      platform::errors::OutOfRange(
                          "The element of Index must be less than the size of "
                          "input dim size of axis which is %d, but received "
                          "index element which is %d in the %d index.",
                          input_index_dim_size, idx_ptr[i], i));
    PADDLE_ENFORCE_GE(idx_ptr[i], 0,
                      platform::errors::OutOfRange(
                          "The element of Index must be greater than or equal "
                          "to 0, but received index element which is %d in the "
                          "%d index.",
                          idx_ptr[i], i));
  }

  // Build the output shape while accumulating the prefix/suffix products.
  int64_t prefix = 1;
  int64_t suffix = 1;
  std::vector<int64_t> result_shape;

  for (int d = 0; d < axis_index; d++) {
    prefix *= src_dims[d];
    result_shape.push_back(src_dims[d]);
  }
  result_shape.push_back(num_indices);
  for (int d = axis_index + 1; d < src_dims.size(); d++) {
    suffix *= src_dims[d];
    result_shape.push_back(src_dims[d]);
  }

  out->Resize(framework::make_ddim(result_shape));
  T* dst = out->mutable_data<T>(place);

  // Walk the output sequentially; for each (prefix, index) pair copy one
  // contiguous slice of `suffix` elements from the source.
  const int64_t src_stride = total_src / prefix;  // axis_extent * suffix
  for (int64_t p = 0; p < prefix; p++) {
    const T* src_base = src_ptr + p * src_stride;
    for (int64_t j = 0; j < num_indices; j++) {
      const T* slice = src_base + idx_ptr[j] * suffix;
      for (int64_t k = 0; k < suffix; k++) {
        *dst++ = slice[k];
      }
    }
  }
}

// Gradient of GatherV2 on CPU: scatter-adds slices of `input` (the upstream
// gradient, shaped like the forward output) into `out` (the gradient w.r.t.
// the forward input) at the positions given by `index` along `axis`.
//
// BUG FIX: the read offset previously omitted the leading-dimension term,
// reading `input_data[j * outer_dim_size + k]` for every `i`. When axis > 0
// and the product of the leading dims (inner_dim_size) is > 1, that
// re-accumulated the first slab of the upstream gradient into every leading
// slice and ignored the rest. The offset now includes
// `i * outer_dim_size * input_index_dim_size`, mirroring the forward
// indexing in GatherV2Function.
//
// NOTE(review): index values are not range checked here (unlike the forward
// path); an out-of-range index writes out of bounds — confirm callers
// validate.
template <typename T, typename U>
void GatherV2GradFunction(const Tensor* input, const Tensor* index,
                          const int axis, Tensor* out,
                          const paddle::platform::Place& place) {
  auto* index_data = index->data<U>();

  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  // Empty upstream gradient: nothing to accumulate.
  if (input->numel() == 0) return;
  int axis_index = axis;
  // Extent of the gather axis in the upstream gradient (== number of
  // indices gathered by the forward op).
  int64_t input_index_dim_size = input_dim[axis_index];

  // Products of the dims before (inner) / after (outer) the axis; note the
  // naming is inverted w.r.t. the usual inner/outer convention.
  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
  }

  auto* out_data = out->mutable_data<T>(place);
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  auto out_dim = out->dims();
  int64_t out_index_dim_size = out_dim[axis_index];
  // Zero-fill first: several indices may target the same output slice and
  // their contributions are accumulated with += below.
  operators::math::set_constant(*dev_ctx, out, 0.0);

  for (int64_t i = 0; i < inner_dim_size; i++) {
    for (int64_t j = 0; j < input_index_dim_size; j++) {
      for (int64_t k = 0; k < outer_dim_size; k++) {
        int64_t index = k + index_data[j] * outer_dim_size +
                        i * outer_dim_size * out_index_dim_size;
        // Read offset now carries the leading-dim term (the bug fix).
        out_data[index] +=
            input_data[i * outer_dim_size * input_index_dim_size +
                       j * outer_dim_size + k];
      }
    }
  }
}

} // namespace operators
} // namespace paddle
44 changes: 28 additions & 16 deletions paddle/fluid/operators/gather_op.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -18,6 +15,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -52,11 +50,29 @@ class GatherOp : public framework::OperatorWithKernel {
index_dims.size()));
}

int batch_size = ctx->GetInputDim("Index")[0];
framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
auto axis = ctx->Attrs().Get<int>("axis");
auto input_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Axis") || axis == 0) {
// if HasInput("Axis"), we can not obtain correct shape of output
int batch_size = index_dims[0];
framework::DDim output_dims(input_dim);
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = framework::make_ddim(out_dim_vec);
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}

protected:
Expand Down Expand Up @@ -120,27 +136,23 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
"If true, update the grad using the overwrite mode in same index,"
"If false, using the accumulate mode in same index.")
.SetDefault(true);
AddAttr<int>(
"axis",
"The Tensor which contains the axis that we do gather operation.")
.SetDefault(0);
AddComment(R"DOC(
Gather Operator.
$Out = X[Index]$
Out is obtained by gathering entries of the outer-most dimension
of X indexed by Index and concatenate them together.
Example:
X = [[1, 2],
[3, 4],
[5, 6]]
Index = [[1, 2]]
Then:
Out = [[3, 4],
[5, 6]]
)DOC");
}
};
Expand Down
Loading

0 comments on commit f39b582

Please sign in to comment.