diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 94a53b0ff6c920..116812b69480da 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -30,6 +30,7 @@ ParamAliasDecorator, VariableArgsDecorator, expand_decorator, + index_add_decorator, param_one_alias, param_two_alias, reshape_decorator, @@ -7599,18 +7600,38 @@ def scatter_add_( ) +@index_add_decorator() def index_add( - x: Tensor, index: Tensor, axis: int, value: Tensor, name: str | None = None + x: Tensor, + index: Tensor, + axis: int, + value: Tensor, + alpha: int = 1, + out: Tensor | None = None, + name: str | None = None, ) -> Tensor: """ Adds the elements of the input tensor with value tensor by selecting the indices in the order given in index. + .. note:: + Alias and Order Support: + 1. The parameter name ``input`` can be used as an alias for ``x``. + 2. The parameter name ``dim`` can be used as an alias for ``axis``. + 3. The parameter name ``source`` can be used as an alias for ``value``. + 4. This API also supports the PyTorch argument order ``(input, dim, index, source)`` for positional arguments, which will be converted to the Paddle order ``(x, index, axis, value)``. + For example, ``paddle.index_add(input=x, dim=1, index=idx, source=val)`` is equivalent to ``paddle.index_add(x=x, axis=1, index=idx, value=val)``, and ``paddle.index_add(x, 1, idx, val)`` is equivalent to ``paddle.index_add(x, idx, 1, val)``. + Args: x (Tensor) : The Destination Tensor. Supported data types are int32, int64, float16, float32, float64. + alias: ``input``. index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64. axis (int): The dimension in which we index. + alias: ``dim``. value (Tensor): The tensor used to add the elements along the target axis. + alias: ``source``. + alpha (Number, optional): Scaling factor for value. Default: 1. + out (Tensor, optional): The output tensor. Default: None. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: @@ -7634,7 +7655,8 @@ def index_add( [2., 2., 2.]]) """ if in_dynamic_or_pir_mode(): - return _C_ops.index_add(x, index, value, axis) + scaled_value = value * alpha if alpha != 1 else value + return _C_ops.index_add(x, index, scaled_value, axis, out=out) helper = LayerHelper("index_add", **locals()) check_variable_and_dtype( @@ -7655,15 +7677,36 @@ def index_add( ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'paddle.tensor.manipulation.index_add', ) + scaled_value = ( + helper.create_variable_for_type_inference(value.dtype) + if alpha != 1 + else value + ) - out = helper.create_variable_for_type_inference(x.dtype) + if alpha != 1: + helper.append_op( + type='scale', + inputs={'X': [value]}, + outputs={'Out': [scaled_value]}, + attrs={'scale': alpha, 'bias': 0.0}, + ) + + if out is not None: + check_variable_and_dtype( + out, + 'out', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'paddle.tensor.manipulation.index_add', + ) + else: + out = helper.create_variable_for_type_inference(x.dtype) helper.append_op( type='index_add', inputs={ 'X': x, 'Index': index, - 'AddValue': value, + 'AddValue': scaled_value, }, outputs={'Out': out}, attrs={'axis': axis}, @@ -7671,15 +7714,22 @@ def index_add( return out +@index_add_decorator() @inplace_apis_in_dygraph_only def index_add_( - x: Tensor, index: Tensor, axis: int, value: Tensor, name: str | None = None + x: Tensor, + index: Tensor, + axis: int, + value: Tensor, + alpha: int = 1, + name: str | None = None, ) -> Tensor: """ Inplace version of ``index_add`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_index_add`. """ - return _C_ops.index_add_(x, index, value, axis) + scaled_value = value * alpha if alpha != 1 else value + return _C_ops.index_add_(x, index, scaled_value, axis) @inplace_apis_in_dygraph_only diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py index cb22ec87955d54..4c3e02dd8ddeda 100644 --- a/python/paddle/utils/decorator_utils.py +++ b/python/paddle/utils/decorator_utils.py @@ -734,3 +734,64 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: return wrapper return decorator + + +def index_add_decorator() -> Callable[ + [Callable[_InputT, _RetT]], Callable[_InputT, _RetT] +]: + """ + Usage Example: + PyTorch: index_add(input, dim, index, source, *, alpha=1) + torch.index_add(input_tensor, 1, indices, source_tensor) + + Paddle: index_add(x, index, axis, value, alpha=1) + paddle.index_add(x=input_tensor, index=indices, axis=1, value=source_tensor) + paddle.index_add(input_tensor, indices, 1, source_tensor) + """ + + def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: + @functools.wraps(func) + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: + if "input" in kwargs and "x" not in kwargs: + kwargs["x"] = kwargs.pop("input") + if "dim" in kwargs and "axis" not in kwargs: + kwargs["axis"] = kwargs.pop("dim") + if "source" in kwargs and "value" not in kwargs: + kwargs["value"] = kwargs.pop("source") + + if len(args) >= 2 and isinstance(args[1], int): + if len(args) < 3 and "index" not in kwargs: + raise TypeError( + "index_add() missing 1 required positional argument: 'index'" + ) + if ( + len(args) < 4 + and "source" not in kwargs + and "value" not in kwargs + ): + raise TypeError( + "index_add() missing 1 required positional argument: 'source'" + ) + if "x" not in kwargs: + kwargs["x"] = args[0] + if "axis" not in kwargs: + kwargs["axis"] = args[1] + if len(args) == 3: + if "index" not in kwargs: + kwargs["index"] = args[2] + args = args[3:] + elif len(args) >= 4: + if "index" not in kwargs: + kwargs["index"] = args[2] + if "value" not in kwargs: + kwargs["value"] = args[3] + args = args[4:] + else: + args = args[2:] + + return func(*args, **kwargs) + + wrapper.__signature__ = inspect.signature(func) + return wrapper + + return decorator diff --git a/test/legacy_test/test_index_add_op.py b/test/legacy_test/test_index_add_op.py index 2c5aae7fd77e68..63e71c283192d6 100644 --- a/test/legacy_test/test_index_add_op.py +++ b/test/legacy_test/test_index_add_op.py @@ -554,5 +554,162 @@ def test_check_grad_normal(self): ) +def get_places(): + places = [] + if paddle.base.is_compiled_with_cuda() or is_custom_device(): + places.append(get_device_place()) + places.append(paddle.CPUPlace()) + return places + + +class TestIndexAddAPI_Compatibility(unittest.TestCase): + def setUp(self): + np.random.seed(2025) + self.places = get_places() + self.shape = [10, 20] + self.index_shape = [5] + self.axis = 1 + self.dtype = "float32" + self.value_shape = list(self.shape) + self.value_shape[self.axis] = self.index_shape[0] + self.init_data() + + def init_data(self): + self.np_input = np.random.rand(*self.shape).astype(self.dtype) + self.np_index = np.random.randint( + 0, self.shape[self.axis], self.index_shape + ).astype("int64") + self.np_value = np.random.rand(*self.value_shape).astype(self.dtype) + + def get_ref_out(self, alpha=1.0): + ref_out = np.copy(self.np_input) + idx = [slice(None)] * len(self.shape) + idx[self.axis] = self.np_index + np.add.at(ref_out, tuple(idx), self.np_value * alpha) + return ref_out + + def test_dygraph_Compatibility(self): + paddle.disable_static() + x = paddle.to_tensor(self.np_input) + index = paddle.to_tensor(self.np_index) + value = paddle.to_tensor(self.np_value) + paddle_dygraph_out = [] + + ref_out = self.get_ref_out() + # 1. Position args (Paddle style: x, index, axis, value) + out1 = paddle.index_add(x, index, self.axis, value) + paddle_dygraph_out.append(out1) + # 2. Key words args (kwargs) for paddle + out2 = paddle.index_add(x=x, index=index, axis=self.axis, value=value) + paddle_dygraph_out.append(out2) + # 3. Key words args (kwargs) for torch + out3 = paddle.index_add( + input=x, dim=self.axis, index=index, source=value + ) + paddle_dygraph_out.append(out3) + # 4. PyTorch positional args order: (input, dim, index, source) + out4 = paddle.index_add(x, self.axis, index, value) + paddle_dygraph_out.append(out4) + # 5. Tensor method args (Paddle style) + out5 = x.index_add(index, self.axis, value) + paddle_dygraph_out.append(out5) + # 6. Tensor method kwargs (PyTorch style) + out6 = x.index_add(dim=self.axis, index=index, source=value) + paddle_dygraph_out.append(out6) + # 7. Mixed args and kwargs (PyTorch positional + kwargs) + out7 = paddle.index_add(x, self.axis, index, source=value) + paddle_dygraph_out.append(out7) + # 8. Mixed args and kwargs (PyTorch positional + kwargs) + out8 = paddle.index_add(x, self.axis, index=index, source=value) + paddle_dygraph_out.append(out8) + # 9. Test 'out' parameter + out9 = paddle.empty_like(x) + paddle.index_add( + input=x, dim=self.axis, index=index, source=value, out=out9 + ) + paddle_dygraph_out.append(out9) + # 10. Test 'alpha' parameter + alpha = 2.0 + out10 = paddle.index_add( + input=x, dim=self.axis, index=index, source=value, alpha=alpha + ) + ref_out_alpha = self.get_ref_out(alpha=alpha) + + for out in paddle_dygraph_out: + np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-05) + np.testing.assert_allclose(ref_out_alpha, out10.numpy(), rtol=1e-05) + paddle.enable_static() + + def test_static_Compatibility(self): + paddle.enable_static() + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.base.program_guard(main, startup): + x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype) + index = paddle.static.data( + name="index", shape=self.index_shape, dtype="int64" + ) + value = paddle.static.data( + name="value", shape=self.value_shape, dtype=self.dtype + ) + # 1. Position args (Paddle style: x, index, axis, value) + out1 = paddle.index_add(x, index, self.axis, value) + # 2. Key words args (kwargs) for paddle + out2 = paddle.index_add( + x=x, index=index, axis=self.axis, value=value + ) + # 3. Key words args (kwargs) for torch + out3 = paddle.index_add( + input=x, dim=self.axis, index=index, source=value + ) + # 4. PyTorch positional args order: (input, dim, index, source) + out4 = paddle.index_add(x, self.axis, index, value) + # 5. Tensor method args (Paddle style) + out5 = x.index_add(index, self.axis, value) + # 6. Tensor method kwargs (PyTorch style) + out6 = x.index_add(dim=self.axis, index=index, source=value) + # 7. Mixed args and kwargs + out7 = paddle.index_add(x, self.axis, index, source=value) + # 8. Mixed args and kwargs + out8 = paddle.index_add(x, self.axis, index=index, source=value) + # 9. Test 'alpha' parameter + alpha = 2.0 + out9 = paddle.index_add( + input=x, dim=self.axis, index=index, source=value, alpha=alpha + ) + ref_out = self.get_ref_out() + ref_out_alpha = self.get_ref_out(alpha=alpha) + + fetch_list = [ + out1, + out2, + out3, + out4, + out5, + out6, + out7, + out8, + out9, + ] + feed_dict = { + "x": self.np_input, + "index": self.np_index, + "value": self.np_value, + } + + for place in self.places: + exe = paddle.base.Executor(place) + fetches = exe.run( + main, + feed=feed_dict, + fetch_list=fetch_list, + ) + for out in fetches[:-1]: + np.testing.assert_allclose(out, ref_out, rtol=1e-05) + np.testing.assert_allclose( + fetches[-1], ref_out_alpha, rtol=1e-05 + ) + + if __name__ == '__main__': unittest.main()