10 changes: 10 additions & 0 deletions python/paddle/tensor/math.py
@@ -42,6 +42,7 @@
ParamAliasDecorator,
param_one_alias,
param_two_alias,
sum_decorator,
)
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

@@ -1603,6 +1604,7 @@ def fmin(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
return _elementwise_op(LayerHelper('elementwise_fmin', **locals()))


@sum_decorator()
def sum(
x: Tensor,
axis: int | Sequence[int] | None = None,
@@ -1613,14 +1615,22 @@ def sum(
"""
Computes the sum of tensor elements over the given dimension.

    .. note::
        Parameter order: when arguments are passed positionally, the positions of ``dtype`` and ``keepdim`` may be swapped.
        For example, ``sum(x, axis, keepdim, dtype)`` is equivalent to ``sum(x, axis, dtype, keepdim)``.
        Alias support: ``input`` may be used as an alias for ``x``, and ``dim`` as an alias for ``axis``.
        For example, ``sum(input=tensor_x, dim=1)`` is equivalent to ``sum(x=tensor_x, axis=1)``.

Args:
x (Tensor): An N-D Tensor, the data type is bool, bfloat16, float16, float32, float64,
uint8, int8, int16, int32, int64, complex64, complex128.
alias: ``input``.
axis (int|list|tuple|None, optional): The dimensions along which the sum is performed. If
:attr:`None`, sum all elements of :attr:`x` and return a
Tensor with a single element, otherwise must be in the
range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`,
the dimension to reduce is :math:`rank + axis[i]`.
alias: ``dim``.
dtype (str|paddle.dtype|np.dtype, optional): The dtype of output Tensor. The default value is None, the dtype
of output is the same as input Tensor `x`.
keepdim (bool, optional): Whether to reserve the reduced dimension in the
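To illustrate the note in the docstring above, here is a minimal sketch of the call forms this patch makes equivalent (assumes a paddle build containing the change; the tensor shape and dtypes are arbitrary):

import paddle

x = paddle.ones([2, 3, 4], dtype='float32')

# paddle's native positional order: sum(x, axis, dtype, keepdim)
a = paddle.sum(x, 0, 'int32', False)
# torch-style positional order: sum(input, dim, keepdim, dtype)
b = paddle.sum(x, 0, False, 'int32')
# torch-style keyword aliases
c = paddle.sum(input=x, dim=0, keepdim=False, dtype='int32')

assert a.dtype == b.dtype == c.dtype == paddle.int32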
35 changes: 35 additions & 0 deletions python/paddle/utils/decorator_utils.py
@@ -538,3 +538,38 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
return wrapper

return decorator


def sum_decorator():
    def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
        @functools.wraps(func)
        def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
            # Rename the torch-style keyword aliases to paddle's names.
            if ("input" in kwargs) and ("x" not in kwargs):
                kwargs["x"] = kwargs.pop("input")
            if ("dim" in kwargs) and ("axis" not in kwargs):
                kwargs["axis"] = kwargs.pop("dim")
            # Three positional arguments: the third is either ``dtype``
            # (paddle order) or ``keepdim`` (torch order); only
            # ``keepdim`` can be a Python bool.
            if len(args) == 3:
                kwargs["x"] = args[0]
                kwargs["axis"] = args[1]
                if isinstance(args[2], bool):
                    kwargs["keepdim"] = args[2]
                else:
                    kwargs["dtype"] = args[2]
                args = ()
            # Four positional arguments: ``dtype`` and ``keepdim`` may
            # appear in either order; disambiguate by type.
            elif len(args) == 4:
                kwargs["x"] = args[0]
                kwargs["axis"] = args[1]
                if isinstance(args[2], bool):
                    kwargs["keepdim"] = args[2]
                    kwargs["dtype"] = args[3]
                else:
                    kwargs["dtype"] = args[2]
                    kwargs["keepdim"] = args[3]
                args = ()

            return func(*args, **kwargs)

        # Expose the original signature for introspection.
        wrapper.__signature__ = inspect.signature(func)
        return wrapper

    return decorator
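The dispatch above can be exercised directly on a stand-in function. A minimal sketch (``fake_sum`` is hypothetical and merely mirrors ``paddle.sum``'s parameter order; it assumes the module path shown in this diff):

from paddle.utils.decorator_utils import sum_decorator

@sum_decorator()
def fake_sum(x, axis=None, dtype=None, keepdim=False, name=None):
    # Echo the keyword arguments the wrapper produced.
    return {'x': x, 'axis': axis, 'dtype': dtype, 'keepdim': keepdim}

# Both positional orders normalize to the same keyword call:
assert fake_sum('t', 0, 'int32', True) == fake_sum('t', 0, True, 'int32')
# The torch-style aliases are renamed before the call:
assert fake_sum(input='t', dim=0) == fake_sum(x='t', axis=0)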
182 changes: 182 additions & 0 deletions test/legacy_test/test_sum_decorator.py
@@ -0,0 +1,182 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from utils import dygraph_guard, static_guard

import paddle
from paddle import enable_static


class TestSumOp_Compatibility(unittest.TestCase):
def setUp(self):
self.shape = [2, 3, 4]
self.axis = 0
self.input_dtype = 'float32'
self.test_dtypes = [
"int32",
"float32",
]

def test_dygraph(self):
with dygraph_guard():
x_paddle = paddle.ones(shape=self.shape, dtype=self.input_dtype)
for dtype_input in self.test_dtypes:
numpy_result = np.sum(
x_paddle.numpy(),
axis=self.axis,
dtype=np.dtype(dtype_input),
keepdims=False,
)

# paddle test case
paddle_result0 = paddle.sum(x_paddle, self.axis, dtype_input)
np.testing.assert_allclose(paddle_result0, numpy_result)

paddle_result1 = paddle.sum(
x_paddle, self.axis, dtype_input, False
)
np.testing.assert_allclose(paddle_result1, numpy_result)

paddle_result2 = paddle.sum(
x=x_paddle, axis=self.axis, dtype=dtype_input, keepdim=False
)
np.testing.assert_allclose(paddle_result2, numpy_result)

# torch test case
paddle_result3 = paddle.sum(
input=x_paddle, dim=self.axis, keepdim=False
)
self.assertEqual(paddle_result3.dtype, paddle.float32)

paddle_result4 = paddle.sum(
input=x_paddle,
dim=self.axis,
keepdim=False,
dtype=dtype_input,
)
np.testing.assert_allclose(paddle_result4, numpy_result)

paddle_result5 = paddle.sum(
x_paddle, self.axis, keepdim=False, dtype=dtype_input
)
np.testing.assert_allclose(paddle_result5, numpy_result)

paddle_result6 = paddle.sum(
x_paddle, self.axis, False, dtype=dtype_input
)
np.testing.assert_allclose(paddle_result6, numpy_result)

paddle_result7 = paddle.sum(
x_paddle, self.axis, False, dtype_input
)
np.testing.assert_allclose(paddle_result7, numpy_result)

paddle_result8 = paddle.sum(
x_paddle, self.axis, dtype_input, False
)
np.testing.assert_allclose(paddle_result8, numpy_result)

paddle_result9 = paddle.sum(x_paddle, self.axis, False)
self.assertEqual(paddle_result9.dtype, paddle.float32)

paddle_result10 = paddle.sum(x_paddle, self.axis, dtype_input)
np.testing.assert_allclose(paddle_result10, numpy_result)

def test_static(self):
self.test_dtypes = [
paddle.int32,
paddle.int64,
paddle.float64,
paddle.bool,
]
with static_guard():
for dtype_input in self.test_dtypes:
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x_paddle = paddle.static.data(
name='x', shape=self.shape, dtype=self.input_dtype
)

# paddle test case
paddle_result0 = paddle.sum(
x_paddle, axis=self.axis, dtype=dtype_input
)
self.assertEqual(paddle_result0.dtype, dtype_input)

paddle_result1 = paddle.sum(
x_paddle,
axis=self.axis,
dtype=dtype_input,
keepdim=False,
)
self.assertEqual(paddle_result1.dtype, dtype_input)

paddle_result2 = paddle.sum(
x=x_paddle,
axis=self.axis,
dtype=dtype_input,
keepdim=False,
)
self.assertEqual(paddle_result2.dtype, dtype_input)

# torch test case
paddle_result3 = paddle.sum(
input=x_paddle, dim=self.axis, keepdim=False
)
self.assertEqual(paddle_result3.dtype, paddle.float32)

paddle_result4 = paddle.sum(
input=x_paddle,
dim=self.axis,
keepdim=False,
dtype=dtype_input,
)
self.assertEqual(paddle_result4.dtype, dtype_input)

paddle_result5 = paddle.sum(
x_paddle, self.axis, keepdim=False, dtype=dtype_input
)
self.assertEqual(paddle_result5.dtype, dtype_input)

paddle_result6 = paddle.sum(
x_paddle, self.axis, False, dtype=dtype_input
)
self.assertEqual(paddle_result6.dtype, dtype_input)

paddle_result7 = paddle.sum(
x_paddle, self.axis, False, dtype_input
)
self.assertEqual(paddle_result7.dtype, dtype_input)

paddle_result8 = paddle.sum(
x_paddle, self.axis, dtype_input, False
)
self.assertEqual(paddle_result8.dtype, dtype_input)

paddle_result9 = paddle.sum(x_paddle, self.axis, False)
self.assertEqual(paddle_result9.dtype, paddle.float32)

paddle_result10 = paddle.sum(
x_paddle, self.axis, dtype_input
)
self.assertEqual(paddle_result10.dtype, dtype_input)


if __name__ == "__main__":
enable_static()
unittest.main()