Skip to content

Commit d434c9e

Browse files
authored
[API compatibility] add scatter_reduce api (#74564)
* add scatter reduce api * cancel paramAliasDecorator * add keyword-only * fix test scatter reduce * fix test note * fix testscase and static check
1 parent 5e376d8 commit d434c9e

File tree

4 files changed

+1214
-0
lines changed

4 files changed

+1214
-0
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@
370370
scatter_,
371371
scatter_nd,
372372
scatter_nd_add,
373+
scatter_reduce,
373374
select_scatter,
374375
shard_index,
375376
slice,
@@ -1231,6 +1232,7 @@
12311232
'renorm',
12321233
'renorm_',
12331234
'take_along_axis',
1235+
'scatter_reduce',
12341236
'put_along_axis',
12351237
'select_scatter',
12361238
'multigammaln',

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@
207207
scatter_,
208208
scatter_nd,
209209
scatter_nd_add,
210+
scatter_reduce,
210211
select_scatter,
211212
shard_index,
212213
slice,
@@ -819,6 +820,7 @@
819820
'moveaxis',
820821
'repeat_interleave',
821822
'take_along_axis',
823+
'scatter_reduce',
822824
'put_along_axis',
823825
'select_scatter',
824826
'put_along_axis_',

python/paddle/tensor/manipulation.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6942,6 +6942,68 @@ def take_along_axis(
69426942
return result
69436943

69446944

6945+
def scatter_reduce(
6946+
input: Tensor,
6947+
dim: int,
6948+
index: Tensor,
6949+
src: Tensor,
6950+
reduce: Literal['sum', 'prod', 'mean', 'amin', 'amax'],
6951+
*,
6952+
include_self: bool = True,
6953+
) -> Tensor:
6954+
"""
6955+
Scatter the values of the source tensor to the target tensor according to the given indices, and perform a reduction operation along the designated axis.
6956+
6957+
Args:
6958+
input (Tensor) : The Input Tensor. Supported data types are bfloat16, float16, float32, float64,
6959+
int32, int64, uint8.
6960+
dim (int) : The axis to scatter 1d slices along.
6961+
index (Tensor) : Indices to scatter along each 1d slice of input. This must match the dimension of input,
6962+
Supported data type are int32 and int64.
6963+
src (Tensor) : The value element(s) to scatter. The data types should be same as input.
6964+
reduce (str): The reduce operation, support 'sum', 'prod', 'mean', 'amin', 'amax'.
6965+
include_self (bool, optional): whether to reduce with the elements of input, default is 'True'.
6966+
6967+
Returns:
6968+
Tensor, The indexed element, same dtype with input
6969+
6970+
Examples:
6971+
.. code-block:: python
6972+
6973+
>>> import paddle
6974+
6975+
>>> x = paddle.to_tensor([[10, 20, 30], [40, 50, 60]])
6976+
>>> indices = paddle.zeros((2,3)).astype("int32")
6977+
>>> values = paddle.to_tensor([[1, 2, 3],[4, 5, 6]]).astype(x.dtype)
6978+
>>> result = paddle.scatter_reduce(x, 0, indices, values, "sum", include_self=True)
6979+
>>> print(result)
6980+
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
6981+
[[15, 27, 39],
6982+
[40, 50, 60]])
6983+
6984+
>>> result = paddle.scatter_reduce(x, 0, indices, values, "prod", include_self=True)
6985+
>>> print(result)
6986+
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
6987+
[[40 , 200, 540],
6988+
[40 , 50 , 60 ]])
6989+
6990+
>>> result = paddle.scatter_reduce(x, 0, indices, values, "mean", include_self=True)
6991+
>>> print(result)
6992+
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
6993+
[[5 , 9 , 13],
6994+
[40, 50, 60]])
6995+
6996+
"""
6997+
6998+
if reduce == 'sum':
6999+
reduce = 'add'
7000+
if reduce == 'prod':
7001+
reduce = 'multiply'
7002+
return put_along_axis(
7003+
input, index, src, dim, reduce, include_self, broadcast=False
7004+
)
7005+
7006+
69457007
def put_along_axis(
69467008
arr: Tensor,
69477009
indices: Tensor,

0 commit comments

Comments
 (0)