From 27c8b443741f4e681453665ec457e2379774fc35 Mon Sep 17 00:00:00 2001 From: zjhyj Date: Wed, 5 Jan 2022 17:41:38 +0800 Subject: [PATCH] index_select_gather --- paddle/fluid/operators/gather.cu.h | 85 ++++++++++ paddle/fluid/operators/gather.h | 99 +++++++++++ paddle/fluid/operators/gather_op.cc | 44 +++-- paddle/fluid/operators/gather_op.cu | 111 +++++------- paddle/fluid/operators/gather_op.h | 95 ++++------- paddle/fluid/operators/index_select_op.cc | 5 +- paddle/fluid/operators/index_select_op.cu | 196 +++++++++++++++++++++- paddle/fluid/operators/index_select_op.h | 155 +++++++++-------- python/paddle/tensor/manipulation.py | 50 +++--- python/paddle/tensor/search.py | 5 - 10 files changed, 583 insertions(+), 262 deletions(-) diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index 94fe45dac0ce7..194157f6ae5ff 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -297,5 +297,90 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, input_data, index_data, out_data, outer_dim_size, inner_dim_size, input_index_dim_size, out_index_dim_size, input_size); } + +template +void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, + const int axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + int64_t index_size = index->numel(); + int64_t input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + auto* index_data = index->data(); + + if (input->numel() == 0) return; + + int axis_index = axis; + int64_t index_dim_size = input_dim[axis_index]; + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + std::vector out_dim_vec; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + int64_t out_size = out->numel(); + if (out_size == 0) return; + + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size); + auto stream = ctx.cuda_device_context().stream(); + GatherGPUKernel< + T, U><<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + index_size, index_dim_size, out_size); +} + +template +void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, + const int axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + auto* index_data = index->data(); + int64_t index_size = index->numel(); + int64_t input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + int axis_index = axis; + int64_t input_index_dim_size = input_dim[axis_index]; + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + auto out_dim = out->dims(); + int64_t out_index_dim_size = out_dim[axis_index]; + operators::math::set_constant(*dev_ctx, out, 0.0); + + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_size); + auto stream = ctx.cuda_device_context().stream(); + GatherGradGPUKernel< + T, U><<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + input_index_dim_size, out_index_dim_size, input_size); +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather.h b/paddle/fluid/operators/gather.h index c12a3b8adc978..c4aaeda741d3f 100644 --- a/paddle/fluid/operators/gather.h +++ b/paddle/fluid/operators/gather.h @@ -231,5 +231,104 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index, } } +template +void GatherV2Function(const Tensor* input, const Tensor* index, int axis, + Tensor* out, const paddle::platform::Place& place) { + auto* index_data = index->data(); + int64_t index_size = index->numel(); + int64_t input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + int axis_index = axis; + + int64_t input_index_dim_size = input_dim[axis_index]; + for (int64_t i = 0; i < index_size; i++) { + PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size, + platform::errors::OutOfRange( + "The element of Index must be less than the size of " + "input dim size of axis which is %d, but received " + "index element which is %d in the %d index.", + input_index_dim_size, index_data[i], i)); + PADDLE_ENFORCE_GE(index_data[i], 0, + platform::errors::OutOfRange( + "The element of Index must be greater than or equal " + "to 0, but received index element which is %d in the " + "%d index.", + index_data[i], i)); + } + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + std::vector out_dim_vec; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + + int out_index = 0; + for (int64_t i = 0; i < inner_dim_size; i++) { + for (int64_t j = 0; j < index_size; j++) { + for (int64_t k = 0; k < outer_dim_size; k++) { + int64_t index = k + index_data[j] * outer_dim_size + + (i * input_size / inner_dim_size); + out_data[out_index] = input_data[index]; + out_index++; + } + } + } +} + +template +void GatherV2GradFunction(const Tensor* input, const Tensor* index, + const int axis, Tensor* out, + const paddle::platform::Place& place) { + auto* index_data = index->data(); + + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + int axis_index = axis; + int64_t input_index_dim_size = input_dim[axis_index]; + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + auto out_dim = out->dims(); + int64_t out_index_dim_size = out_dim[axis_index]; + operators::math::set_constant(*dev_ctx, out, 0.0); + + for (int64_t i = 0; i < inner_dim_size; i++) { + for (int64_t j = 0; j < input_index_dim_size; j++) { + for (int64_t k = 0; k < outer_dim_size; k++) { + int64_t index = k + index_data[j] * outer_dim_size + + i * outer_dim_size * out_index_dim_size; + out_data[index] += input_data[j * outer_dim_size + k]; + } + } + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 162766546b3c2..374776c7806d4 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,6 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/op_version_registry.h" + namespace paddle { namespace operators { @@ -52,11 +50,29 @@ class GatherOp : public framework::OperatorWithKernel { index_dims.size())); } - int batch_size = ctx->GetInputDim("Index")[0]; - framework::DDim output_dims(ctx->GetInputDim("X")); - output_dims[0] = batch_size; - ctx->SetOutputDim("Out", output_dims); - ctx->ShareLoD("X", /*->*/ "Out"); + auto axis = ctx->Attrs().Get("axis"); + auto input_dim = ctx->GetInputDim("X"); + if (ctx->HasInput("Axis") || axis == 0) { + // if HasInput("Axis"), we can not obtain correct shape of output + int batch_size = index_dims[0]; + framework::DDim output_dims(input_dim); + output_dims[0] = batch_size; + ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } else { + int index_size = index_dims[0]; + std::vector out_dim_vec; + for (int i = 0; i < axis; i++) { + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis + 1; i < input_dim.size(); i++) { + out_dim_vec.push_back(input_dim[i]); + } + auto output_dims = framework::make_ddim(out_dim_vec); + ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } } protected: @@ -120,27 +136,23 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { "If true, update the grad using the overwrite mode in same index," "If false, using the accumulate mode in same index.") .SetDefault(true); + AddAttr( + "axis", + "The Tensor which contains the axis that we do gather operation.") + .SetDefault(0); AddComment(R"DOC( Gather Operator. - $Out = X[Index]$ - Out is obtained by gathering entries of the outer-most dimension of X indexed by Index and concatenate them together. - Example: - X = [[1, 2], [3, 4], [5, 6]] - Index = [[1, 2]] - Then: - Out = [[3, 4], [5, 6]] - )DOC"); } }; diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 37fbfb21f60a0..2722a356aeddf 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -31,47 +28,33 @@ class GatherOpCUDAKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + int axis = ctx.Attr("axis"); + + // get axis from tensor if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); + Tensor cpu_axis; + const Tensor *axis_tensor = ctx.Input("Axis"); + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); + } + const auto &place = ctx.GetPlace(); + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(x, index, axis, output, place, ctx); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(x, index, axis, output, place, ctx); } return; } + output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { GPUGather(ctx.device_context(), *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { @@ -91,30 +74,27 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); + const Tensor *axis_tensor = ctx.Input("Axis"); + Tensor cpu_axis; + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); + } + + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + ctx.GetPlace(), ctx); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + ctx.GetPlace(), ctx); } return; } @@ -125,19 +105,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { .eigen_device(); dxt.device(place) = dxt.constant(static_cast(0)); if (dO->numel() == 0) return; - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { GPUScatterAssign(ctx, *dO, *index, dX, ctx.Attr("overwrite")); diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h index 8ec0d6ce0b69c..bda7e137e8de0 100644 --- a/paddle/fluid/operators/gather_op.h +++ b/paddle/fluid/operators/gather_op.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -35,45 +32,30 @@ class GatherOpKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + int axis = ctx.Attr("axis"); + // get axis from tensor if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(x, index, axis, output, place); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(x, index, axis, output, place); + const Tensor *axis_tensor = ctx.Input("Axis"); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(axis_tensor->data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(axis_tensor->data()[0]); } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(x, index, axis, output, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(x, index, axis, output, place); + } + const auto &place = ctx.GetPlace(); + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2Function(x, index, axis, output, place); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2Function(x, index, axis, output, place); } return; } output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { CPUGather(ctx.device_context(), *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { @@ -94,26 +76,23 @@ class GatherGradientOpKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(dO, index, axis, dX, place); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(dO, index, axis, dX, place); + const Tensor *axis_tensor = ctx.Input("Axis"); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(axis_tensor->data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(axis_tensor->data()[0]); } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(dO, index, axis, dX, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(dO, index, axis, dX, place); + } + const auto &index_type = index->type(); + + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(dO, index, axis, dX, ctx.GetPlace()); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(dO, index, axis, dX, ctx.GetPlace()); } return; } @@ -126,18 +105,6 @@ class GatherGradientOpKernel : public framework::OpKernel { if (dO->numel() == 0) return; bool overwrite = ctx.Attr("overwrite"); - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { if (overwrite) { ScatterAssign(ctx.device_context(), *dO, *index, dX); diff --git a/paddle/fluid/operators/index_select_op.cc b/paddle/fluid/operators/index_select_op.cc index 60ca7e2fe7cfd..d2e5aaebad186 100644 --- a/paddle/fluid/operators/index_select_op.cc +++ b/paddle/fluid/operators/index_select_op.cc @@ -54,6 +54,10 @@ class IndexSelectOp : public framework::OperatorWithKernel { "the dimension of Input(Index) is [%d].", index_dim, index_dim.size())); + PADDLE_ENFORCE_EQ(index_dim[0] != 0, true, + platform::errors::InvalidArgument( + "The length of Input(Index) can't be 0.")); + auto output_dim = framework::vectorize(input_dim); if (dim < 0) { dim += input_dim.size(); @@ -112,7 +116,6 @@ class IndexSelectOpMaker : public framework::OpProtoAndCheckerMaker { Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a Tensor. - The returned tensor has the same number of dimensions as the original tensor (input). The dim-th dimension has the same size as the length of index; other dimensions diff --git a/paddle/fluid/operators/index_select_op.cu b/paddle/fluid/operators/index_select_op.cu index 36a91d98a2ade..168e56bf111cf 100644 --- a/paddle/fluid/operators/index_select_op.cu +++ b/paddle/fluid/operators/index_select_op.cu @@ -12,18 +12,198 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/index_select_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +__global__ void index_select_cuda_kernel(const T* input, T* output, + const IndexT* index, int64_t N, + int64_t stride, int64_t size, + int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + output[idx] = input[input_idx]; +} + +template +__global__ void index_select_grad_cuda_kernel(const T* output_grad, + T* input_grad, + const IndexT* index, int64_t nums, + int64_t N, int64_t stride, + int64_t size, int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]); +} + +template +__global__ void index_select_grad_init(T* input_grad, int64_t N) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + input_grad[idx] = 0.0; +} + +template +class IndexSelectCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* index = context.Input("Index"); + auto* out = context.Output("Out"); + int dim = context.Attr("dim"); + auto input_dim = in->dims(); + auto output_dim = out->dims(); + dim = dim >= 0 ? dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + auto* in_data = in->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + int64_t numel = out->numel(); + + auto stream = + context.template device_context().stream(); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_select_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data, + numel, stride, size, delta); + platform::GpuStreamSync(stream); + } else { + const int* index_data = index->data(); + index_select_cuda_kernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_data, out_data, index_data, numel, stride, size, delta); + platform::GpuStreamSync(stream); + } + } +}; + +template +class IndexSelectGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* in_grad = context.Output(framework::GradVarName("X")); + auto* index = context.Input("Index"); + + auto* output_grad_data = output_grad->data(); + auto* in_grad_data = in_grad->mutable_data(context.GetPlace()); + + int dim = context.Attr("dim"); + auto input_dim = in_grad->dims(); + auto output_dim = output_grad->dims(); + dim = dim >= 0 ? dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + int64_t numel = in_grad->numel(); + int64_t index_nums = index->numel(); + int64_t out_nums = output_grad->numel(); + + auto stream = + context.template device_context().stream(); + + index_select_grad_init< + T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_select_grad_cuda_kernel<<< + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, + index_data, index_nums, + out_nums, stride, size, delta); + platform::GpuStreamSync(stream); + } else { + const int* index_data = index->data(); + index_select_grad_cuda_kernel<<< + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, + index_data, index_nums, + out_nums, stride, size, delta); + platform::GpuStreamSync(stream); + } + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( index_select, - ops::IndexSelectKernel, - ops::IndexSelectKernel, - ops::IndexSelectKernel, - ops::IndexSelectKernel); + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel); REGISTER_OP_CUDA_KERNEL( index_select_grad, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel); + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel); diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h index 70714b7f3a064..be76a66ef7c96 100644 --- a/paddle/fluid/operators/index_select_op.h +++ b/paddle/fluid/operators/index_select_op.h @@ -15,6 +15,8 @@ #pragma once #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -23,71 +25,69 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; using DDim = framework::DDim; -template +template void IndexSelectInner(const framework::ExecutionContext& context, - const LoDTensor& input, const LoDTensor& index, + LoDTensor* input, const LoDTensor& index, LoDTensor* output, int dim) { - auto input_dim = input.dims(); + auto input_dim = input->dims(); auto input_dim_size = input_dim.size(); auto output_dim = output->dims(); + auto index_size = index.dims()[0]; + + LoDTensor index_cpu_copy; + if (!platform::is_cpu_place(index.place())) { + framework::TensorCopySync(index, platform::CPUPlace(), &index_cpu_copy); + } + const IndexT* index_data = platform::is_cpu_place(index.place()) + ? index.data() + : index_cpu_copy.data(); + output->mutable_data(context.GetPlace()); auto slice_size = 1; for (auto i = dim + 1; i < input_dim_size; i++) { slice_size *= input_dim[i]; } - auto input_width = slice_size * input_dim[dim]; - auto output_width = slice_size * output_dim[dim]; - auto outer_nums = 1; for (auto i = 0; i < dim; i++) { outer_nums *= input_dim[i]; } - auto index_size = index.dims()[0]; - - std::vector input_vec; - std::vector index_vec; - TensorToVector(input, context.device_context(), &input_vec); - TensorToVector(index, context.device_context(), &index_vec); - std::vector out_vec(output->numel()); - for (int i = 0; i < index_size; i++) { PADDLE_ENFORCE_GE( - index_vec[i], 0, + index_data[i], 0, platform::errors::InvalidArgument( "Variable value (index) of OP(index_select) " "expected >= 0 and < %ld, but got %ld. Please check input " "value.", - input_dim[dim], index_vec[i])); + input_dim[dim], index_data[i])); PADDLE_ENFORCE_LT( - index_vec[i], input_dim[dim], + index_data[i], input_dim[dim], platform::errors::InvalidArgument( "Variable value (index) of OP(index_select) " "expected >= 0 and < %ld, but got %ld. Please check input " "value.", - input_dim[dim], index_vec[i])); + input_dim[dim], index_data[i])); } VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums - << "; slice_size: " << slice_size << "; input_width: " << input_width - << "; output_width: " << output_width - << "; index_size: " << index_size; + << "; slice_size: " << slice_size << "; index_size: " << index_size; - for (auto i = 0; i < outer_nums; i++) { - auto input_start_offset = i * input_width; - auto output_start_offset = i * output_width; + input->Resize(framework::make_ddim({outer_nums, input_dim[dim], slice_size})); + output->Resize(framework::make_ddim({outer_nums, index_size, slice_size})); - for (auto j = 0; j < index_size; j++) { - IndexT index_value = index_vec[j]; - for (auto k = 0; k < slice_size; k++) { - out_vec[output_start_offset + j * slice_size + k] = - input_vec[input_start_offset + index_value * slice_size + k]; - } - } + auto input_tensor = framework::EigenTensor::From(*input); + auto output_tensor = framework::EigenTensor::From(*output); + + auto& place = + *context.template device_context().eigen_device(); + + for (auto j = 0; j < index_size; j++) { + IndexT index_value = index_data[j]; + auto output_t = output_tensor.chip(j, 1); + output_t.device(place) = input_tensor.chip(index_value, 1); } - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(out_vec, context.device_context(), output); + input->Resize(input_dim); output->Resize(output_dim); } @@ -95,20 +95,15 @@ template class IndexSelectKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* inputs_var = context.InputVar("X"); - auto* index_var = context.InputVar("Index"); - auto* output_var = context.OutputVar("Out"); - - auto& inputs = inputs_var->Get(); - auto& index = index_var->Get(); - auto* output = output_var->GetMutable(); + auto inputs = *context.Input("X"); + auto* index = context.Input("Index"); + auto* output = context.Output("Out"); int dim = context.Attr("dim"); if (dim < 0) { dim += inputs.dims().size(); } - - const auto& index_type = index.type(); + const auto& index_type = index->type(); bool index_type_match = index_type == framework::proto::VarType::INT32 || index_type == framework::proto::VarType::INT64; PADDLE_ENFORCE_EQ(index_type_match, true, @@ -122,26 +117,50 @@ class IndexSelectKernel : public framework::OpKernel { framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { - IndexSelectInner(context, inputs, index, output, dim); + IndexSelectInner(context, &inputs, *index, output, + dim); } else if (index_type == framework::proto::VarType::INT64) { - IndexSelectInner(context, inputs, index, output, dim); + IndexSelectInner(context, &inputs, *index, + output, dim); + } + } +}; + +template +struct IndexSelectAdd { + void operator()(const framework::ExecutionContext& ctx, int slice_size, + const T* src_pointer, const T* p_pointer, T* dist_pointer) { + for (int i = 0; i < slice_size; i++) { + dist_pointer[i] = src_pointer[i] + p_pointer[i]; } } }; +template +struct IndexSelectAdd< + DeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext& ctx, int slice_size, + const T* src_pointer, const T* p_pointer, T* dist_pointer) { + auto blas = math::GetBlas(ctx); + blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer); + } +}; -template +template void IndexSelectGradInner(const framework::ExecutionContext& context, const LoDTensor& out_grad, const LoDTensor& index, LoDTensor* x_grad, int dim) { - std::vector input_vec; - std::vector index_vec; - TensorToVector(out_grad, context.device_context(), &input_vec); - TensorToVector(index, context.device_context(), &index_vec); - + const T* input_data = out_grad.data(); + const IndexT* index_data = index.data(); + const T* p_output = x_grad->mutable_data(context.GetPlace()); + T* out_data = x_grad->mutable_data(context.GetPlace()); auto input_dim = out_grad.dims(); auto input_dim_size = input_dim.size(); auto output_dim = x_grad->dims(); - std::vector out_vec(x_grad->numel(), 0); + + auto& dev_ctx = context.template device_context(); + math::SetConstant set_constant; + set_constant(dev_ctx, x_grad, static_cast(0.0)); auto slice_size = 1; for (auto i = dim + 1; i < input_dim_size; i++) { @@ -167,15 +186,14 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, auto output_start_offset = i * output_width; for (auto j = 0; j < index_size; j++) { - IndexT index_value = index_vec[j]; - for (auto k = 0; k < slice_size; k++) { - out_vec[output_start_offset + index_value * slice_size + k] += - input_vec[input_start_offset + j * slice_size + k]; - } + IndexT index_value = index_data[j]; + auto src = input_data + input_start_offset + j * slice_size; + auto p_out = p_output + output_start_offset + index_value * slice_size; + auto dst = out_data + output_start_offset + index_value * slice_size; + IndexSelectAdd index_select_add; + index_select_add(context, slice_size, src, p_out, dst); } } - x_grad->mutable_data(context.GetPlace()); - framework::TensorFromVector(out_vec, context.device_context(), x_grad); x_grad->Resize(output_dim); } @@ -183,19 +201,18 @@ template class IndexSelectGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* index_var = context.InputVar("Index"); - auto* x_grad_var = context.OutputVar(framework::GradVarName("X")); - auto* out_grad_var = context.InputVar(framework::GradVarName("Out")); + auto* x_grad = + context.Output(framework::GradVarName("X")); + auto* index = context.Input("Index"); + auto* out_grad = + context.Input(framework::GradVarName("Out")); - auto& index = index_var->Get(); - auto& out_grad = out_grad_var->Get(); - auto* x_grad = x_grad_var->GetMutable(); int dim = context.Attr("dim"); if (dim < 0) { - dim += out_grad.dims().size(); + dim += out_grad->dims().size(); } + const auto& index_type = index->type(); - const auto& index_type = index.type(); bool index_type_match = index_type == framework::proto::VarType::INT32 || index_type == framework::proto::VarType::INT64; PADDLE_ENFORCE_EQ(index_type_match, true, @@ -209,9 +226,11 @@ class IndexSelectGradKernel : public framework::OpKernel { framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { - IndexSelectGradInner(context, out_grad, index, x_grad, dim); + IndexSelectGradInner(context, *out_grad, *index, + x_grad, dim); } else if (index_type == framework::proto::VarType::INT64) { - IndexSelectGradInner(context, out_grad, index, x_grad, dim); + IndexSelectGradInner(context, *out_grad, + *index, x_grad, dim); } } }; diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 40a8fdb7ef095..94aaa4330f26a 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -715,24 +715,16 @@ def gather(x, index, axis=None, name=None): """ Output is obtained by gathering entries of ``axis`` of ``x`` indexed by ``index`` and concatenate them together. - .. code-block:: text - - Given: - x = [[1, 2], [3, 4], [5, 6]] - index = [1, 2] axis=[0] - Then: - out = [[3, 4], [5, 6]] - Args: x (Tensor): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), @@ -741,16 +733,12 @@ def gather(x, index, axis=None, name=None): axis (Tensor|int, optional): The axis of input to be gathered, it's can be int or a Tensor with data type is int32 or int64. The default value is None, if None, the ``axis`` is 0. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . - Returns: output (Tensor): The output is a tensor with the same rank as ``x``. Examples: - .. code-block:: python - import paddle - input = paddle.to_tensor([[1,2],[3,4],[5,6]]) index = paddle.to_tensor([0,1]) output = paddle.gather(input, index, axis=0) @@ -758,34 +746,40 @@ def gather(x, index, axis=None, name=None): """ if axis is None: axis = 0 - axis_tensor = axis - if not isinstance(axis, Variable) and axis == 0: - return paddle.fluid.layers.gather(input=x, index=index, overwrite=False) - if not isinstance(axis, Variable): - with device_guard("cpu"): - axis_tensor = fill_constant( - shape=[1], dtype='int64', value=axis, force_cpu=True) + if in_dygraph_mode(): - return core.ops.gather(x, index, axis_tensor) + #not support for in_dygraph_mode, please fix it + raise ValueError( + "not support for in_dygraph_mode, please fix it") check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') + if isinstance(axis, Variable): check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather') - else: - check_type(axis, 'axis', (int), 'gather') helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype('x') out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type="gather", - inputs={"X": x, - "Index": index, - "Axis": axis_tensor}, - outputs={"Out": out}) + if not isinstance(axis, Variable): + helper.append_op( + type="gather", + inputs={"X": x, + "Index": index}, + attrs={'axis': axis, + 'overwrite': False}, + outputs={"Out": out}) + else: + helper.append_op( + type="gather", + inputs={"X": x, + "Index": index, + "Axis": axis}, + attrs={"overwrite": False}, + outputs={"Out": out}) + return out diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 32f7bf373bbbd..7bdfa9858ae1a 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -274,12 +274,10 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): def index_select(x, index, axis=0, name=None): """ - Returns a new tensor which indexes the ``input`` tensor along dimension ``axis`` using the entries in ``index`` which is a Tensor. The returned tensor has the same number of dimensions as the original ``x`` tensor. The dim-th dimension has the same size as the length of ``index``; other dimensions have the same size as in the ``x`` tensor. - Args: x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float32, float64, int32, int64. index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64. @@ -287,7 +285,6 @@ def index_select(x, index, axis=0, name=None): name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - Returns: Tensor: A Tensor with same data type as ``x``. @@ -295,7 +292,6 @@ def index_select(x, index, axis=0, name=None): .. code-block:: python import paddle - x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]) @@ -329,7 +325,6 @@ def index_select(x, index, axis=0, name=None): attrs={'dim': axis}) return out - def nonzero(x, as_tuple=False): """ Return a tensor containing the indices of all non-zero elements of the `input`