Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add index_put api #52886

Merged
merged 28 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
21c5464
add index_put api
Courtesy-Xs Apr 13, 2023
a75ded8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs Apr 13, 2023
91c30e6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs Apr 13, 2023
9da71b6
fix some bugs
Courtesy-Xs Apr 14, 2023
4538c1a
fix value broadcast in backward and add test case in static
Courtesy-Xs Apr 16, 2023
244d02d
fix cpu backward bug
Courtesy-Xs Apr 17, 2023
01672f8
add timeout=120s for index_put
Courtesy-Xs Apr 17, 2023
5a361ea
add op_compat for index_put
Courtesy-Xs Apr 17, 2023
a7f2d42
delete input_put in op_compat.yaml
Courtesy-Xs Apr 17, 2023
d996d36
add inplace index_put test
Courtesy-Xs Apr 17, 2023
8a3fef4
refactor code
Courtesy-Xs Apr 18, 2023
5f77bb5
add test case when index tensor in indices is int32 when indices.size…
Courtesy-Xs Apr 18, 2023
6267d32
add index_put api backward in cpu place
Courtesy-Xs Apr 18, 2023
fdd0436
add backward test case
Courtesy-Xs Apr 18, 2023
86d6cac
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs Apr 28, 2023
7b71a3a
fix take in init.py bug
Courtesy-Xs Apr 28, 2023
48a03c6
refactor code according to review result
Courtesy-Xs May 6, 2023
9b2d455
alter 2022 to 2023 in copyright declaration
Courtesy-Xs May 6, 2023
0c6545a
refactor code to delete some duplicated code
Courtesy-Xs May 6, 2023
894adb1
replaace reshape with resize for decrease extra memcpy
Courtesy-Xs May 8, 2023
ed7a141
add datatype flag in backward yaml
Courtesy-Xs May 8, 2023
c92f75e
replace macro with template with conditional complilation
Courtesy-Xs May 8, 2023
4de9b48
fix rocmn bug
Courtesy-Xs May 9, 2023
ed00d81
fix note and rocmn bug
Courtesy-Xs May 9, 2023
f956aee
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Courtesy-Xs May 9, 2023
43167ab
fix conflict between flatten and index_put
Courtesy-Xs May 9, 2023
b09221f
fix bug in documentation
Courtesy-Xs May 9, 2023
db0209f
Update python/paddle/tensor/manipulation.py
Ligoml May 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,16 @@
data_type : out_grad
inplace : (out_grad -> x_grad)

- backward_op : index_put_grad
forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out)
args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
output : Tensor(x_grad), Tensor(value_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, value]
kernel :
func : index_put_grad

- backward_op : index_sample_grad
forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,16 @@
inplace : (x -> out)
backward : index_add_grad

- op : index_put
args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false)
output : Tensor(out)
infer_meta :
func : IndexPutInferMeta
kernel :
func : index_put
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

输入x和indices的数据类型不同,需要指定按照谁的数据类型来选择kernel,关键字为data_type,写法如后面紧跟的index_sample

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

inplace : (x -> out)
backward : index_put_grad

- op : index_sample
args : (Tensor x, Tensor index)
output : Tensor
Expand Down
16 changes: 15 additions & 1 deletion paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3249,6 +3249,21 @@ void MoeInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

void IndexPutInferMeta(const MetaTensor& x,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InferMeta按照字母序放置

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

const std::vector<const MetaTensor*>& indices,
const MetaTensor& value,
bool accumulate,
MetaTensor* out) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
7,
phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
out->share_meta(x);
}

void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& edge_weight,
Expand Down Expand Up @@ -3295,6 +3310,5 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}

} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,4 +615,10 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out);

void IndexPutInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& indices,
const MetaTensor& value,
bool accumulate,
MetaTensor* out);

} // namespace phi
Loading