Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ParamAliasDecorator,
VariableArgsDecorator,
expand_decorator,
index_add_decorator,
param_one_alias,
param_two_alias,
reshape_decorator,
Expand Down Expand Up @@ -7599,18 +7600,38 @@ def scatter_add_(
)


@index_add_decorator()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个参考下gather,使用下多种overload签名吧,这样代码可读性会好一些。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gather是把不同的参数分发给了两个ops的kernel,但是index_add是把两套参数都转化为一套参数,传入给底层的ops的kernel中,所以改成那样的方式之后,仍然减少不了大量的判断的逻辑

Copy link
Contributor

@zhwesky2010 zhwesky2010 Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gather是把不同的参数分发给了两个ops的kernel,但是index_add是把两套参数都转化为一套参数,传入给底层的ops的kernel中,所以改成那样的方式之后,仍然减少不了大量的判断的逻辑

@fangfangssj 只提升一下代码的可读性,代码逻辑不变。将多套签名在代码里展示出来。参考#76149

def index_add(
x: Tensor, index: Tensor, axis: int, value: Tensor, name: str | None = None
x: Tensor,
index: Tensor,
axis: int,
value: Tensor,
alpha: int = 1,
out: Tensor | None = None,
name: str | None = None,
) -> Tensor:
"""
Adds the elements of the input tensor with value tensor by selecting the indices in the order given in index.

.. note::
Alias and Order Support:
1. The parameter name ``input`` can be used as an alias for ``x``.
2. The parameter name ``dim`` can be used as an alias for ``axis``.
3. The parameter name ``source`` can be used as an alias for ``value``.
4. This API also supports the PyTorch argument order ``(input, dim, index, source)`` for positional arguments, which will be converted to the Paddle order ``(x, index, axis, value)``.
For example, ``paddle.index_add(input=x, dim=1, index=idx, source=val)`` is equivalent to ``paddle.index_add(x=x, axis=1, index=idx, value=val)``, and ``paddle.index_add(x, 1, idx, val)`` is equivalent to ``paddle.index_add(x, idx, 1, val)``.

Args:
x (Tensor) : The Destination Tensor. Supported data types are int32, int64, float16, float32, float64.
alias: ``input``.
index (Tensor): The 1-D Tensor containing the indices to index.
The data type of ``index`` must be int32 or int64.
axis (int): The dimension in which we index.
alias: ``dim``.
value (Tensor): The tensor used to add the elements along the target axis.
alias: ``source``.
alpha (Number, optional): Scaling factor for value. Default: 1.
out (Tensor, optional): The output tensor. Default: None.
name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
Expand All @@ -7634,7 +7655,8 @@ def index_add(
[2., 2., 2.]])
"""
if in_dynamic_or_pir_mode():
return _C_ops.index_add(x, index, value, axis)
scaled_value = value * alpha if alpha != 1 else value
return _C_ops.index_add(x, index, scaled_value, axis, out=out)

helper = LayerHelper("index_add", **locals())
check_variable_and_dtype(
Expand All @@ -7655,31 +7677,59 @@ def index_add(
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'paddle.tensor.manipulation.index_add',
)
scaled_value = (
helper.create_variable_for_type_inference(value.dtype)
if alpha != 1
else value
)

out = helper.create_variable_for_type_inference(x.dtype)
if alpha != 1:
helper.append_op(
type='scale',
inputs={'X': [value]},
outputs={'Out': [scaled_value]},
attrs={'scale': alpha, 'bias': 0.0},
)

if out is not None:
check_variable_and_dtype(
out,
'out',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'paddle.tensor.manipulation.index_add',
)
else:
out = helper.create_variable_for_type_inference(x.dtype)

helper.append_op(
type='index_add',
inputs={
'X': x,
'Index': index,
'AddValue': value,
'AddValue': scaled_value,
},
outputs={'Out': out},
attrs={'axis': axis},
)
return out


@index_add_decorator()
@inplace_apis_in_dygraph_only
def index_add_(
x: Tensor, index: Tensor, axis: int, value: Tensor, name: str | None = None
x: Tensor,
index: Tensor,
axis: int,
value: Tensor,
alpha: int = 1,
name: str | None = None,
) -> Tensor:
"""
Inplace version of ``index_add`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_index_add`.
"""
return _C_ops.index_add_(x, index, value, axis)
scaled_value = value * alpha if alpha != 1 else value
return _C_ops.index_add_(x, index, scaled_value, axis)


@inplace_apis_in_dygraph_only
Expand Down
61 changes: 61 additions & 0 deletions python/paddle/utils/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,64 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
return wrapper

return decorator


def index_add_decorator() -> Callable[
[Callable[_InputT, _RetT]], Callable[_InputT, _RetT]
]:
"""
Usage Example:
PyTorch: index_add(input, dim, index, source, *, alpha=1)
torch.index_add(input_tensor, 1, indices, source_tensor)

Paddle: index_add(x, index, axis, value, alpha=1)
paddle.index_add(x=input_tensor, index=indices, axis=1, value=source_tensor)
paddle.index_add(input_tensor, indices, 1, source_tensor)
"""

def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if "input" in kwargs and "x" not in kwargs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些判断一般只用假设一个条件,要么是torch的这种用法,要么是paddle的这种用法。其他情况无需判断。

kwargs["x"] = kwargs.pop("input")
if "dim" in kwargs and "axis" not in kwargs:
kwargs["axis"] = kwargs.pop("dim")
if "source" in kwargs and "value" not in kwargs:
kwargs["value"] = kwargs.pop("source")

if len(args) >= 2 and isinstance(args[1], int):
if len(args) < 3 and "index" not in kwargs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些报错检查可以非必要不加

raise TypeError(
"index_add() missing 1 required positional argument: 'index'"
)
if (
len(args) < 4
and "source" not in kwargs
and "value" not in kwargs
):
raise TypeError(
"index_add() missing 1 required positional argument: 'source'"
)
if "x" not in kwargs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些if not判断是否不需要,此时必然是torch的用法。

kwargs["x"] = args[0]
if "axis" not in kwargs:
kwargs["axis"] = args[1]
if len(args) == 3:
if "index" not in kwargs:
kwargs["index"] = args[2]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接写一个

args_list = ['x', 'axis', 'index', 'value']
for ele in args:
    # 顺序读取args

不用这么多判断,减少装饰器的非必要损耗。

args = args[3:]
elif len(args) >= 4:
if "index" not in kwargs:
kwargs["index"] = args[2]
if "value" not in kwargs:
kwargs["value"] = args[3]
args = args[4:]
else:
args = args[2:]

return func(*args, **kwargs)

wrapper.__signature__ = inspect.signature(func)
return wrapper

return decorator
157 changes: 157 additions & 0 deletions test/legacy_test/test_index_add_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,5 +554,162 @@ def test_check_grad_normal(self):
)


def get_places():
places = []
if paddle.base.is_compiled_with_cuda() or is_custom_device():
places.append(get_device_place())
places.append(paddle.CPUPlace())
return places


class TestIndexAddAPI_Compatibility(unittest.TestCase):
def setUp(self):
np.random.seed(2025)
self.places = get_places()
self.shape = [10, 20]
self.index_shape = [5]
self.axis = 1
self.dtype = "float32"
self.value_shape = list(self.shape)
self.value_shape[self.axis] = self.index_shape[0]
self.init_data()

def init_data(self):
self.np_input = np.random.rand(*self.shape).astype(self.dtype)
self.np_index = np.random.randint(
0, self.shape[self.axis], self.index_shape
).astype("int64")
self.np_value = np.random.rand(*self.value_shape).astype(self.dtype)

def get_ref_out(self, alpha=1.0):
ref_out = np.copy(self.np_input)
idx = [slice(None)] * len(self.shape)
idx[self.axis] = self.np_index
np.add.at(ref_out, tuple(idx), self.np_value * alpha)
return ref_out

def test_dygraph_Compatibility(self):
paddle.disable_static()
x = paddle.to_tensor(self.np_input)
index = paddle.to_tensor(self.np_index)
value = paddle.to_tensor(self.np_value)
paddle_dygraph_out = []

ref_out = self.get_ref_out()
# 1. Position args (Paddle style: x, index, axis, value)
out1 = paddle.index_add(x, index, self.axis, value)
paddle_dygraph_out.append(out1)
# 2. Key words args (kwargs) for paddle
out2 = paddle.index_add(x=x, index=index, axis=self.axis, value=value)
paddle_dygraph_out.append(out2)
# 3. Key words args (kwargs) for torch
out3 = paddle.index_add(
input=x, dim=self.axis, index=index, source=value
)
paddle_dygraph_out.append(out3)
# 4. PyTorch positional args order: (input, dim, index, source)
out4 = paddle.index_add(x, self.axis, index, value)
paddle_dygraph_out.append(out4)
# 5. Tensor method args (Paddle style)
out5 = x.index_add(index, self.axis, value)
paddle_dygraph_out.append(out5)
# 6. Tensor method kwargs (PyTorch style)
out6 = x.index_add(dim=self.axis, index=index, source=value)
paddle_dygraph_out.append(out6)
# 7. Mixed args and kwargs (PyTorch positional + kwargs)
out7 = paddle.index_add(x, self.axis, index, source=value)
paddle_dygraph_out.append(out7)
# 8. Mixed args and kwargs (PyTorch positional + kwargs)
out8 = paddle.index_add(x, self.axis, index=index, source=value)
paddle_dygraph_out.append(out8)
# 9. Test 'out' parameter
out9 = paddle.empty_like(x)
paddle.index_add(
input=x, dim=self.axis, index=index, source=value, out=out9
)
paddle_dygraph_out.append(out9)
# 10. Test 'alpha' parameter
alpha = 2.0
out10 = paddle.index_add(
input=x, dim=self.axis, index=index, source=value, alpha=alpha
)
ref_out_alpha = self.get_ref_out(alpha=alpha)

for out in paddle_dygraph_out:
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-05)
np.testing.assert_allclose(ref_out_alpha, out10.numpy(), rtol=1e-05)
paddle.enable_static()

def test_static_Compatibility(self):
paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.base.program_guard(main, startup):
x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
index = paddle.static.data(
name="index", shape=self.index_shape, dtype="int64"
)
value = paddle.static.data(
name="value", shape=self.value_shape, dtype=self.dtype
)
# 1. Position args (Paddle style: x, index, axis, value)
out1 = paddle.index_add(x, index, self.axis, value)
# 2. Key words args (kwargs) for paddle
out2 = paddle.index_add(
x=x, index=index, axis=self.axis, value=value
)
# 3. Key words args (kwargs) for torch
out3 = paddle.index_add(
input=x, dim=self.axis, index=index, source=value
)
# 4. PyTorch positional args order: (input, dim, index, source)
out4 = paddle.index_add(x, self.axis, index, value)
# 5. Tensor method args (Paddle style)
out5 = x.index_add(index, self.axis, value)
# 6. Tensor method kwargs (PyTorch style)
out6 = x.index_add(dim=self.axis, index=index, source=value)
# 7. Mixed args and kwargs
out7 = paddle.index_add(x, self.axis, index, source=value)
# 8. Mixed args and kwargs
out8 = paddle.index_add(x, self.axis, index=index, source=value)
# 9. Test 'alpha' parameter
alpha = 2.0
out9 = paddle.index_add(
input=x, dim=self.axis, index=index, source=value, alpha=alpha
)
ref_out = self.get_ref_out()
ref_out_alpha = self.get_ref_out(alpha=alpha)

fetch_list = [
out1,
out2,
out3,
out4,
out5,
out6,
out7,
out8,
out9,
]
feed_dict = {
"x": self.np_input,
"index": self.np_index,
"value": self.np_value,
}

for place in self.places:
exe = paddle.base.Executor(place)
fetches = exe.run(
main,
feed=feed_dict,
fetch_list=fetch_list,
)
for out in fetches[:-1]:
np.testing.assert_allclose(out, ref_out, rtol=1e-05)
np.testing.assert_allclose(
fetches[-1], ref_out_alpha, rtol=1e-05
)


if __name__ == '__main__':
unittest.main()
Loading