add index_put api #52886

Merged (28 commits), May 10, 2023
21c5464
add index_put api
Courtesy-Xs Apr 13, 2023
a75ded8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs Apr 13, 2023
91c30e6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs Apr 13, 2023
9da71b6
fix some bugs
Courtesy-Xs Apr 14, 2023
4538c1a
fix value broadcast in backward and add test case in static
Courtesy-Xs Apr 16, 2023
244d02d
fix cpu backward bug
Courtesy-Xs Apr 17, 2023
01672f8
add timeout=120s for index_put
Courtesy-Xs Apr 17, 2023
5a361ea
add op_compat for index_put
Courtesy-Xs Apr 17, 2023
a7f2d42
delete input_put in op_compat.yaml
Courtesy-Xs Apr 17, 2023
d996d36
add inplace index_put test
Courtesy-Xs Apr 17, 2023
8a3fef4
refactor code
Courtesy-Xs Apr 18, 2023
5f77bb5
add test case when index tensor in indices is int32 when indices.size…
Courtesy-Xs Apr 18, 2023
6267d32
add index_put api backward in cpu place
Courtesy-Xs Apr 18, 2023
fdd0436
add backward test case
Courtesy-Xs Apr 18, 2023
86d6cac
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs Apr 28, 2023
7b71a3a
fix take bug in init.py
Courtesy-Xs Apr 28, 2023
48a03c6
refactor code according to review result
Courtesy-Xs May 6, 2023
9b2d455
alter 2022 to 2023 in copyright declaration
Courtesy-Xs May 6, 2023
0c6545a
refactor code to delete some duplicated code
Courtesy-Xs May 6, 2023
894adb1
replace reshape with resize to decrease extra memcpy
Courtesy-Xs May 8, 2023
ed7a141
add datatype flag in backward yaml
Courtesy-Xs May 8, 2023
c92f75e
replace macro with template with conditional compilation
Courtesy-Xs May 8, 2023
4de9b48
fix ROCm bug
Courtesy-Xs May 9, 2023
ed00d81
fix note and ROCm bug
Courtesy-Xs May 9, 2023
f956aee
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs May 9, 2023
43167ab
fix conflict between flatten and index_put
Courtesy-Xs May 9, 2023
b09221f
fix bug in documentation
Courtesy-Xs May 9, 2023
db0209f
Update python/paddle/tensor/manipulation.py
Ligoml May 9, 2023
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -796,6 +796,17 @@
data_type : out_grad
inplace : (out_grad -> x_grad)

- backward_op : index_put_grad
forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out)
args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
output : Tensor(x_grad), Tensor(value_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, value]
kernel :
func : index_put_grad
data_type : out_grad

- backward_op : index_sample_grad
forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -931,6 +931,17 @@
inplace : (x -> out)
backward : index_add_grad

- op : index_put
args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false)
output : Tensor(out)
infer_meta :
func : IndexPutInferMeta
kernel :
func : index_put
Review comment (Contributor): The inputs `x` and `indices` have different data types, so the op must specify whose data type is used to select the kernel, via the `data_type` keyword; see the `index_sample` op below for the pattern.

Reply (Author): done
data_type : x
inplace : (x -> out)
backward : index_put_grad

- op : index_sample
args : (Tensor x, Tensor index)
output : Tensor
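The forward semantics of the `index_put` op registered above can be sketched in plain Python. This is an illustrative sketch only, not Paddle's implementation; the flat-list layout and the `index_put` helper here are hypothetical stand-ins for the strided scatter the CUDA/CPU kernels perform:

```python
# Illustrative sketch: index_put(x, indices, value, accumulate=False) writes
# (or accumulates) value[n] into x at the position addressed by the n-th
# entry of each per-dimension index list.
def index_put(x, shape, indices, value, accumulate=False):
    # Row-major strides for `shape`, as phi::stride would compute them.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    out = list(x)
    for n in range(len(indices[0])):
        offset = 0
        for dim, idx in enumerate(indices):
            cur = idx[n]
            if cur < 0:               # negative indices wrap, as in the kernel
                cur += shape[dim]
            offset += strides[dim] * cur
        out[offset] = out[offset] + value[n] if accumulate else value[n]
    return out

x = [0.0] * 6  # a 2x3 tensor, flattened
out = index_put(x, [2, 3], [[0, 1], [2, -1]], [5.0, 7.0])
# positions (0, 2) and (1, -1) receive 5.0 and 7.0
```

With `accumulate=True` the written values add onto the existing elements instead of overwriting them, which is exactly the flag exposed in the op signature.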
16 changes: 15 additions & 1 deletion paddle/phi/infermeta/multiary.cc
@@ -1962,6 +1962,21 @@ void InterpolateInferMeta(
}
}

void IndexPutInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& indices,
const MetaTensor& value,
bool accumulate,
MetaTensor* out) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
7,
phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
out->share_meta(x);
}

void LambInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
@@ -3295,6 +3310,5 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}

} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
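The shape-inference contract added in `IndexPutInferMeta` is small enough to restate as a sketch (illustrative Python, not the C++ code path; the function name here is hypothetical): the op only enforces that `x` has rank below 7, and the output shares `x`'s meta.

```python
# Sketch of the constraint in IndexPutInferMeta: rank(x) must be < 7,
# and out->share_meta(x) gives the output x's shape (and dtype).
def index_put_infer_meta(x_shape):
    if not len(x_shape) < 7:
        raise ValueError(
            "The rank of input should be less than 7, "
            f"but received {len(x_shape)}.")
    return list(x_shape)  # output shares x's shape

out_shape = index_put_infer_meta([2, 3, 4])
```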
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -332,6 +332,12 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());

void IndexPutInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& indices,
const MetaTensor& value,
bool accumulate,
MetaTensor* out);

void LambInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
225 changes: 225 additions & 0 deletions paddle/phi/kernels/cpu/index_put_grad_kernel.cc
@@ -0,0 +1,225 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/index_put_grad_kernel.h"
#include <numeric>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

template <typename T>
void set_zero_kernel(const int64_t N,
const int64_t** indices,
const phi::DDim& stride,
const phi::DDim& shape,
T* out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t idx = 0; idx < N; ++idx) {
int64_t cur_ix = 0;
int64_t offset = 0;

for (int i = 0; i < shape.size(); ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
*(out + offset) = 0;
}
}

template <typename T>
void index_put_grad_kernel(const int64_t N,
const T* out_grad,
const int64_t** indices,
const phi::DDim& stride,
const phi::DDim& shape,
T* value_grad) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t idx = 0; idx < N; ++idx) {
int64_t cur_ix = 0;
int64_t offset = 0;

for (int i = 0; i < shape.size(); ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
*(value_grad + idx) = *(out_grad + offset);
}
}
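The two loops above share the same strided-offset arithmetic; only the final write differs. A hedged pure-Python sketch (the `offsets` helper is illustrative, not part of the PR): `set_zero_kernel` zeroes the scattered positions in `x_grad` (the copy of `out_grad`), while `index_put_grad_kernel` gathers `out_grad` at those same offsets into `value_grad`.

```python
# Compute the flat offsets addressed by per-dimension index arrays,
# with negative-index wrapping, as both CPU loops do.
def offsets(indices, shape, strides):
    offs = []
    for n in range(len(indices[0])):
        o = 0
        for dim, idx in enumerate(indices):
            cur = idx[n] + (shape[dim] if idx[n] < 0 else 0)
            o += strides[dim] * cur
        offs.append(o)
    return offs

shape, strides = [2, 3], [3, 1]          # a 2x3 tensor, row-major
out_grad = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
offs = offsets([[0, 1], [2, 2]], shape, strides)

x_grad = list(out_grad)
for o in offs:            # set_zero_kernel: overwritten slots get no gradient
    x_grad[o] = 0.0
value_grad = [out_grad[o] for o in offs]  # index_put_grad_kernel: gather
```

This matches the non-accumulate case: positions that were overwritten in the forward pass contribute their gradient to `value`, not to `x`.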

template <typename T, typename Context>
void LaunchIndexPutGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& out_grad,
bool accumulate,
DenseTensor* value_grad,
DenseTensor* x_grad) {
const int64_t* pd_indices[7];
for (size_t i = 0; i < indices.size(); ++i) {
pd_indices[i] = indices[i]->data<int64_t>();
}

if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (!accumulate) {
T* x_grad_data = x_grad->data<T>();

auto x_grad_dims = x_grad->dims();
const int64_t numel = indices[0]->numel();
auto x_grad_stride = phi::stride(x_grad_dims);

set_zero_kernel<T>(
numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data);
}
}

auto out_grad_dims = out_grad.dims();
const int64_t numel = indices[0]->numel();
auto out_grad_stride = phi::stride(out_grad_dims);

if (value_grad) {
if (value_grad->numel() == 1) {
DenseTensor tmp_value_grad(value_grad->dtype());
tmp_value_grad.Resize(indices[0]->dims());

T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
auto out_grad_data = out_grad.data<T>();

index_put_grad_kernel<T>(numel,
out_grad_data,
pd_indices,
out_grad_stride,
out_grad_dims,
tmp_value_grad_data);

std::vector<int> v_dims(tmp_value_grad.dims().size());
std::iota(v_dims.begin(), v_dims.end(), 0);
IntArray v_axis(v_dims);
SumKernel<T>(dev_ctx,
tmp_value_grad,
v_axis,
value_grad->dtype(),
false,
value_grad);
} else if (value_grad->numel() == indices[0]->numel()) {
T* value_grad_data = dev_ctx.template Alloc<T>(value_grad);
auto out_grad_data = out_grad.data<T>();

index_put_grad_kernel<T>(numel,
out_grad_data,
pd_indices,
out_grad_stride,
out_grad_dims,
value_grad_data);
} else {
DenseTensor tmp_value_grad(value_grad->dtype());
tmp_value_grad.Resize(indices[0]->dims());

T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
auto out_grad_data = out_grad.data<T>();

index_put_grad_kernel<T>(numel,
out_grad_data,
pd_indices,
out_grad_stride,
out_grad_dims,
tmp_value_grad_data);

std::vector<int64_t> after_dims = phi::vectorize(tmp_value_grad.dims());
std::vector<int64_t> before_dims = phi::vectorize(value_grad->dims());
std::vector<int64_t> compress_dims;
std::vector<int64_t> dims_without_1;

funcs::CalCompressedDimsWith1AndWithout1(
&after_dims, &before_dims, &compress_dims, &dims_without_1);

auto pre_dims = value_grad->dims();
value_grad->Resize(phi::make_ddim(dims_without_1));
IntArray v_axis(compress_dims);
SumKernel<T>(dev_ctx,
tmp_value_grad,
v_axis,
value_grad->dtype(),
false,
value_grad);
value_grad->Resize(pre_dims);
}
Review comment (Contributor): This function is too long; consider extracting the common code or wrapping parts of it into helper functions.

Reply (Author): done
}
}

template <typename T, typename Context>
void IndexPutGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& value,
const DenseTensor& out_grad,
bool accumulate,
DenseTensor* x_grad,
DenseTensor* value_grad) {
PADDLE_ENFORCE_EQ(
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor value must be the same as the data type "
"of tensor x."));
std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v =
funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);

std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
std::vector<DenseTensor> tmp_res_indices_v;
std::vector<DenseTensor> range_tensor_v;

for (int i = indices.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}

funcs::DealWithIndices<T, Context>(dev_ctx,
x,
int_indices_v,
&res_indices_v,
&tmp_res_indices_v,
range_tensor_v,
bd_dim,
&res_dim_v);

LaunchIndexPutGradKernel<T, Context>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
}
} // namespace phi

PD_REGISTER_KERNEL(index_put_grad,
CPU,
ALL_LAYOUT,
phi::IndexPutGradKernel,
float,
double,
int,
int64_t,
bool) {}
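The three `value_grad` branches in `LaunchIndexPutGradKernel` reduce the gathered gradients differently. A hedged sketch of the first two branches (values here are illustrative): gather `out_grad` at the scattered offsets, then sum everything when `value` was a broadcast scalar, or keep the per-index gradients when the shapes already match; the third branch sums only over the broadcast axes via `CalCompressedDimsWith1AndWithout1`.

```python
# Sketch of the value_grad reduction branches.
out_grad = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
offs = [2, 5]                            # strided offsets of written positions
gathered = [out_grad[o] for o in offs]   # tmp_value_grad in the kernel

scalar_value_grad = sum(gathered)  # branch 1: value->numel() == 1, SumKernel
full_value_grad = gathered         # branch 2: numel matches indices[0]
```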