diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index fd5589575a7518..b55b57158abc40 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -42,6 +42,7 @@ ParamAliasDecorator, param_one_alias, param_two_alias, + sum_decorator, ) from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only @@ -1603,6 +1604,7 @@ def fmin(x: Tensor, y: Tensor, name: str | None = None) -> Tensor: return _elementwise_op(LayerHelper('elementwise_fmin', **locals())) +@sum_decorator() def sum( x: Tensor, axis: int | Sequence[int] | None = None, @@ -1613,14 +1615,22 @@ def sum( """ Computes the sum of tensor elements over the given dimension. + .. note:: + Parameter order support: When passing positional parameters, it is possible to support swapping the positional order of dtype and axis. + For example, ``sum(x, axis, keepdim, dtype)`` is equivalent to ``sum(x, axis, dtype, keepdim)``. + Alias Support: The parameter name ``input`` can be used as an alias for ``x`` and the parameter name ``dim`` can be used as an alias for ``axis``. + For example, ``sum(input=tensor_x, dim=1)`` is equivalent to ``sum(x=tensor_x, axis=1)``. + Args: x (Tensor): An N-D Tensor, the data type is bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128. + alias: ``input``. axis (int|list|tuple|None, optional): The dimensions along which the sum is performed. If :attr:`None`, sum all elements of :attr:`x` and return a Tensor with a single element, otherwise must be in the range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`, the dimension to reduce is :math:`rank + axis[i]`. + alias: ``dim``. dtype (str|paddle.dtype|np.dtype, optional): The dtype of output Tensor. The default value is None, the dtype of output is the same as input Tensor `x`. keepdim (bool, optional): Whether to reserve the reduced dimension in the diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py index fae116edd53ace..8243e2ff52b16e 100644 --- a/python/paddle/utils/decorator_utils.py +++ b/python/paddle/utils/decorator_utils.py @@ -538,3 +538,38 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: return wrapper return decorator + + +def sum_decorator(): + def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: + @functools.wraps(func) + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: + if ("input" in kwargs) and ("x" not in kwargs): + kwargs["x"] = kwargs.pop("input") + if ("dim" in kwargs) and ("axis" not in kwargs): + kwargs["axis"] = kwargs.pop("dim") + if len(args) == 3: + kwargs["x"] = args[0] + kwargs["axis"] = args[1] + if isinstance(args[2], bool): + kwargs["keepdim"] = args[2] + else: + kwargs["dtype"] = args[2] + args = () + elif len(args) == 4: + kwargs["x"] = args[0] + kwargs["axis"] = args[1] + if isinstance(args[2], bool): + kwargs["keepdim"] = args[2] + kwargs["dtype"] = args[3] + else: + kwargs["dtype"] = args[2] + kwargs["keepdim"] = args[3] + args = () + + return func(*args, **kwargs) + + wrapper.__signature__ = inspect.signature(func) + return wrapper + + return decorator diff --git a/test/legacy_test/test_sum_decorator.py b/test/legacy_test/test_sum_decorator.py new file mode 100644 index 00000000000000..10b5e03d62c3dd --- /dev/null +++ b/test/legacy_test/test_sum_decorator.py @@ -0,0 +1,182 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from utils import dygraph_guard, static_guard + +import paddle +from paddle import enable_static + + +class TestSumOp_Compatibility(unittest.TestCase): + def setUp(self): + self.shape = [2, 3, 4] + self.axis = 0 + self.input_dtype = 'float32' + self.test_dtypes = [ + "int32", + "float32", + ] + + def test_dygraph(self): + with dygraph_guard(): + x_paddle = paddle.ones(shape=self.shape, dtype=self.input_dtype) + for dtype_input in self.test_dtypes: + numpy_result = np.sum( + x_paddle.numpy(), + axis=self.axis, + dtype=np.dtype(dtype_input), + keepdims=False, + ) + + # paddle test case + paddle_result0 = paddle.sum(x_paddle, self.axis, dtype_input) + np.testing.assert_allclose(paddle_result0, numpy_result) + + paddle_result1 = paddle.sum( + x_paddle, self.axis, dtype_input, False + ) + np.testing.assert_allclose(paddle_result1, numpy_result) + + paddle_result2 = paddle.sum( + x=x_paddle, axis=self.axis, dtype=dtype_input, keepdim=False + ) + np.testing.assert_allclose(paddle_result2, numpy_result) + + # torch test case + paddle_result3 = paddle.sum( + input=x_paddle, dim=self.axis, keepdim=False + ) + self.assertEqual(paddle_result3.dtype, paddle.float32) + + paddle_result4 = paddle.sum( + input=x_paddle, + dim=self.axis, + keepdim=False, + dtype=dtype_input, + ) + np.testing.assert_allclose(paddle_result4, numpy_result) + + paddle_result5 = paddle.sum( + x_paddle, self.axis, keepdim=False, dtype=dtype_input + ) + np.testing.assert_allclose(paddle_result5, numpy_result) + + paddle_result6 = paddle.sum( + x_paddle, self.axis, False, dtype=dtype_input + ) + np.testing.assert_allclose(paddle_result6, numpy_result) + + paddle_result7 = paddle.sum( + x_paddle, self.axis, False, dtype_input + ) + np.testing.assert_allclose(paddle_result7, numpy_result) + + paddle_result8 = paddle.sum( + x_paddle, self.axis, dtype_input, False + ) + np.testing.assert_allclose(paddle_result8, numpy_result) + + paddle_result9 = paddle.sum(x_paddle, self.axis, False) + self.assertEqual(paddle_result9.dtype, paddle.float32) + + paddle_result10 = paddle.sum(x_paddle, self.axis, dtype_input) + np.testing.assert_allclose(paddle_result10, numpy_result) + + def test_static(self): + self.test_dtypes = [ + paddle.int32, + paddle.int64, + paddle.float64, + paddle.bool, + ] + with static_guard(): + for dtype_input in self.test_dtypes: + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_paddle = paddle.static.data( + name='x', shape=self.shape, dtype=self.input_dtype + ) + + # paddle test case + paddle_result0 = paddle.sum( + x_paddle, axis=self.axis, dtype=dtype_input + ) + self.assertEqual(paddle_result0.dtype, dtype_input) + + paddle_result1 = paddle.sum( + x_paddle, + axis=self.axis, + dtype=dtype_input, + keepdim=False, + ) + self.assertEqual(paddle_result1.dtype, dtype_input) + + paddle_result2 = paddle.sum( + x=x_paddle, + axis=self.axis, + dtype=dtype_input, + keepdim=False, + ) + self.assertEqual(paddle_result2.dtype, dtype_input) + + # torch test case + paddle_result3 = paddle.sum( + input=x_paddle, dim=self.axis, keepdim=False + ) + self.assertEqual(paddle_result3.dtype, paddle.float32) + + paddle_result4 = paddle.sum( + input=x_paddle, + dim=self.axis, + keepdim=False, + dtype=dtype_input, + ) + self.assertEqual(paddle_result4.dtype, dtype_input) + + paddle_result5 = paddle.sum( + x_paddle, self.axis, keepdim=False, dtype=dtype_input + ) + self.assertEqual(paddle_result5.dtype, dtype_input) + + paddle_result6 = paddle.sum( + x_paddle, self.axis, False, dtype=dtype_input + ) + self.assertEqual(paddle_result6.dtype, dtype_input) + + paddle_result7 = paddle.sum( + x_paddle, self.axis, False, dtype_input + ) + self.assertEqual(paddle_result7.dtype, dtype_input) + + paddle_result8 = paddle.sum( + x_paddle, self.axis, dtype_input, False + ) + self.assertEqual(paddle_result8.dtype, dtype_input) + + paddle_result9 = paddle.sum(x_paddle, self.axis, False) + self.assertEqual(paddle_result9.dtype, paddle.float32) + + paddle_result10 = paddle.sum( + x_paddle, self.axis, dtype_input + ) + self.assertEqual(paddle_result10.dtype, dtype_input) + + +if __name__ == "__main__": + enable_static() + unittest.main()