-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add index_put api #52886
Merged
Xreki
merged 28 commits into
PaddlePaddle:develop
from
Courtesy-Xs:clear_add_index_put_api
May 10, 2023
Merged
add index_put api #52886
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
21c5464
add index_put api
Courtesy-Xs a75ded8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs 91c30e6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs 9da71b6
fix some bugs
Courtesy-Xs 4538c1a
fix value broadcast in backward and add test case in static
Courtesy-Xs 244d02d
fix cpu backward bug
Courtesy-Xs 01672f8
add timeout=120s for index_put
Courtesy-Xs 5a361ea
add op_compat for index_put
Courtesy-Xs a7f2d42
delete input_put in op_compat.yaml
Courtesy-Xs d996d36
add inplace index_put test
Courtesy-Xs 8a3fef4
refactor code
Courtesy-Xs 5f77bb5
add test case when index tensor in indices is int32 when indices.size…
Courtesy-Xs 6267d32
add index_put api backward in cpu place
Courtesy-Xs fdd0436
add backward test case
Courtesy-Xs 86d6cac
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs 7b71a3a
fix take in init.py bug
Courtesy-Xs 48a03c6
refactor code according to review result
Courtesy-Xs 9b2d455
alter 2022 to 2023 in copyright declaration
Courtesy-Xs 0c6545a
refactor code to delete some duplicated code
Courtesy-Xs 894adb1
replaace reshape with resize for decrease extra memcpy
Courtesy-Xs ed7a141
add datatype flag in backward yaml
Courtesy-Xs c92f75e
replace macro with template with conditional complilation
Courtesy-Xs 4de9b48
fix rocmn bug
Courtesy-Xs ed00d81
fix note and rocmn bug
Courtesy-Xs f956aee
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs 43167ab
fix conflict between flatten and index_put
Courtesy-Xs b09221f
fix bug in documentation
Courtesy-Xs db0209f
Update python/paddle/tensor/manipulation.py
Ligoml File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/index_put_grad_kernel.h" | ||
#include <numeric> | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/cast_kernel.h" | ||
#include "paddle/phi/kernels/funcs/index_put_utils.h" | ||
#include "paddle/phi/kernels/reduce_sum_kernel.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T> | ||
void set_zero_kernel(const int64_t N, | ||
const int64_t** indices, | ||
const phi::DDim& stride, | ||
const phi::DDim& shape, | ||
T* out) { | ||
#ifdef PADDLE_WITH_MKLML | ||
#pragma omp parallel for | ||
#endif | ||
for (int64_t idx = 0; idx < N; ++idx) { | ||
int64_t cur_ix = 0; | ||
int64_t offset = 0; | ||
|
||
for (int i = 0; i < shape.size(); ++i) { | ||
cur_ix = (static_cast<int64_t>(*(indices[i] + idx))); | ||
if (cur_ix < 0) { | ||
cur_ix += shape[i]; | ||
} | ||
offset += stride[i] * cur_ix; | ||
} | ||
*(out + offset) = 0; | ||
} | ||
} | ||
|
||
template <typename T> | ||
void index_put_grad_kernel(const int64_t N, | ||
const T* out_grad, | ||
const int64_t** indices, | ||
const phi::DDim& stride, | ||
const phi::DDim& shape, | ||
T* value_grad) { | ||
#ifdef PADDLE_WITH_MKLML | ||
#pragma omp parallel for | ||
#endif | ||
for (int64_t idx = 0; idx < N; ++idx) { | ||
int64_t cur_ix = 0; | ||
int64_t offset = 0; | ||
|
||
for (int i = 0; i < shape.size(); ++i) { | ||
cur_ix = (static_cast<int64_t>(*(indices[i] + idx))); | ||
if (cur_ix < 0) { | ||
cur_ix += shape[i]; | ||
} | ||
offset += stride[i] * cur_ix; | ||
} | ||
*(value_grad + idx) = *(out_grad + offset); | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void LaunchIndexPutGradKernel(const Context& dev_ctx, | ||
const std::vector<const DenseTensor*>& indices, | ||
const DenseTensor& out_grad, | ||
bool accumulate, | ||
DenseTensor* value_grad, | ||
DenseTensor* x_grad) { | ||
const int64_t* pd_indices[7]; | ||
for (size_t i = 0; i < indices.size(); ++i) { | ||
pd_indices[i] = indices[i]->data<int64_t>(); | ||
} | ||
|
||
if (x_grad) { | ||
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); | ||
if (!accumulate) { | ||
T* x_grad_data = x_grad->data<T>(); | ||
|
||
auto x_grad_dims = x_grad->dims(); | ||
const int64_t numel = indices[0]->numel(); | ||
auto x_grad_stride = phi::stride(x_grad_dims); | ||
|
||
set_zero_kernel<T>( | ||
numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data); | ||
} | ||
} | ||
|
||
auto out_grad_dims = out_grad.dims(); | ||
const int64_t numel = indices[0]->numel(); | ||
auto out_grad_stride = phi::stride(out_grad_dims); | ||
|
||
if (value_grad) { | ||
if (value_grad->numel() == 1) { | ||
DenseTensor tmp_value_grad(value_grad->dtype()); | ||
tmp_value_grad.Resize(indices[0]->dims()); | ||
|
||
T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad); | ||
auto out_grad_data = out_grad.data<T>(); | ||
|
||
index_put_grad_kernel<T>(numel, | ||
out_grad_data, | ||
pd_indices, | ||
out_grad_stride, | ||
out_grad_dims, | ||
tmp_value_grad_data); | ||
|
||
std::vector<int> v_dims(tmp_value_grad.dims().size()); | ||
std::iota(v_dims.begin(), v_dims.end(), 0); | ||
IntArray v_axis(v_dims); | ||
SumKernel<T>(dev_ctx, | ||
tmp_value_grad, | ||
v_axis, | ||
value_grad->dtype(), | ||
false, | ||
value_grad); | ||
} else if (value_grad->numel() == indices[0]->numel()) { | ||
T* value_grad_data = dev_ctx.template Alloc<T>(value_grad); | ||
auto out_grad_data = out_grad.data<T>(); | ||
|
||
index_put_grad_kernel<T>(numel, | ||
out_grad_data, | ||
pd_indices, | ||
out_grad_stride, | ||
out_grad_dims, | ||
value_grad_data); | ||
} else { | ||
DenseTensor tmp_value_grad(value_grad->dtype()); | ||
tmp_value_grad.Resize(indices[0]->dims()); | ||
|
||
T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad); | ||
auto out_grad_data = out_grad.data<T>(); | ||
|
||
index_put_grad_kernel<T>(numel, | ||
out_grad_data, | ||
pd_indices, | ||
out_grad_stride, | ||
out_grad_dims, | ||
tmp_value_grad_data); | ||
|
||
std::vector<int64_t> after_dims = phi::vectorize(tmp_value_grad.dims()); | ||
std::vector<int64_t> before_dims = phi::vectorize(value_grad->dims()); | ||
std::vector<int64_t> compress_dims; | ||
std::vector<int64_t> dims_without_1; | ||
|
||
funcs::CalCompressedDimsWith1AndWithout1( | ||
&after_dims, &before_dims, &compress_dims, &dims_without_1); | ||
|
||
auto pre_dims = value_grad->dims(); | ||
value_grad->Resize(phi::make_ddim(dims_without_1)); | ||
IntArray v_axis(compress_dims); | ||
SumKernel<T>(dev_ctx, | ||
tmp_value_grad, | ||
v_axis, | ||
value_grad->dtype(), | ||
false, | ||
value_grad); | ||
value_grad->Resize(pre_dims); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个函数太长了,建议提取些公共代码,或者封装一些函数。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void IndexPutGradKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const std::vector<const DenseTensor*>& indices, | ||
const DenseTensor& value, | ||
const DenseTensor& out_grad, | ||
bool accumulate, | ||
DenseTensor* x_grad, | ||
DenseTensor* value_grad) { | ||
PADDLE_ENFORCE_EQ( | ||
x.dtype(), | ||
value.dtype(), | ||
phi::errors::InvalidArgument( | ||
"The data type of tensor in indices must be same to the data type " | ||
"of tensor x.")); | ||
std::vector<DenseTensor> tmp_args; | ||
std::vector<const phi::DenseTensor*> int_indices_v = | ||
funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args); | ||
auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); | ||
|
||
std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim)); | ||
std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr); | ||
std::vector<DenseTensor> tmp_res_indices_v; | ||
std::vector<DenseTensor> range_tensor_v; | ||
|
||
for (int i = indices.size(); i < x.dims().size(); ++i) { | ||
range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>( | ||
dev_ctx, x.dims()[i], phi::DataType::INT64)); | ||
} | ||
|
||
funcs::DealWithIndices<T, Context>(dev_ctx, | ||
x, | ||
int_indices_v, | ||
&res_indices_v, | ||
&tmp_res_indices_v, | ||
range_tensor_v, | ||
bd_dim, | ||
&res_dim_v); | ||
|
||
LaunchIndexPutGradKernel<T, Context>( | ||
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); | ||
} | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(index_put_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::IndexPutGradKernel, | ||
float, | ||
double, | ||
int, | ||
int64_t, | ||
bool) {} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
输入x和indices的数据类型不同,需要指定按照谁的数据类型来选择kernel,关键字为
data_type
,写法如后面紧跟的index_sample
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done