7 changes: 7 additions & 0 deletions python/paddle/framework/dtype.py
@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
from paddle.utils.decorator_utils import ParamAliasDecorator

from ..base import framework
from ..base.core import (
@@ -205,16 +206,22 @@ def iinfo(dtype):
return core_iinfo(dtype)


@ParamAliasDecorator({"dtype": ["type"]})
def finfo(dtype):
"""

``paddle.finfo`` returns an object that represents the numerical properties of a floating point
``paddle.dtype``.
This is similar to `numpy.finfo <https://numpy.org/doc/stable/reference/generated/numpy.finfo.html#numpy-finfo>`_.

.. note::
Alias Support: The parameter name ``type`` can be used as an alias for ``dtype``.
For example, ``type=paddle.float32`` is equivalent to ``dtype=paddle.float32``.

Args:
dtype(paddle.dtype|string): One of ``paddle.float16``, ``paddle.float32``, ``paddle.float64``, ``paddle.bfloat16``,
``paddle.complex64``, and ``paddle.complex128``.
type: An alias for ``dtype``, with identical behavior.

Returns:
An ``finfo`` object, which has the following 8 attributes:
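For reference, a minimal usage sketch of the new ``type`` alias on ``paddle.finfo`` (illustrative only; both calls are equivalent once the decorator maps ``type`` to ``dtype``):

    import paddle

    # The alias form is rewritten to dtype=... before finfo runs.
    info_a = paddle.finfo(dtype=paddle.float32)
    info_b = paddle.finfo(type=paddle.float32)  # alias form added by this change
    assert info_a.eps == info_b.eps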
8 changes: 8 additions & 0 deletions python/paddle/nn/functional/common.py
@@ -31,6 +31,7 @@
)
from paddle.tensor.creation import full
from paddle.utils import deprecated
from paddle.utils.decorator_utils import ParamAliasDecorator
from paddle.utils.layers_utils import NotSupportedTensorArgumentError

from ...base.data_feeder import (
@@ -1788,6 +1789,7 @@ def feature_alpha_dropout(
)


@ParamAliasDecorator({"x": ["input"]})
def pad(
x: Tensor,
pad: ShapeLike,
@@ -1823,8 +1825,14 @@ def pad(
4. If mode is ``'reflect'``, pad[0] and pad[1] must be no greater than width-1. The height and depth
dimensions have the same condition.

.. note::
Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
For example, ``input=tensor_x`` is equivalent to ``x=tensor_x``.


Args:
x (Tensor): The input tensor with data type float32, float64, int32, int64, complex64 or complex128.
input: An alias for ``x``, with identical behavior.
pad (Tensor|list[int]|tuple[int]): The padding size with data type int. Refer to Note for details.
mode (str, optional): Four modes: ``'constant'`` (default), ``'reflect'``, ``'replicate'``, ``'circular'``. Default is ``'constant'``.

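A minimal sketch of the new ``input`` alias on ``paddle.nn.functional.pad`` (illustrative; mirrors the new tests below):

    import paddle
    import paddle.nn.functional as F

    x = paddle.ones([2, 3])
    # The two calls are equivalent; "input" is rewritten to "x".
    out_x = F.pad(x=x, pad=[1, 2, 3, 4], value=0.5)
    out_input = F.pad(input=x, pad=[1, 2, 3, 4], value=0.5)  # alias form
    assert (out_x.numpy() == out_input.numpy()).all()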
8 changes: 8 additions & 0 deletions python/paddle/tensor/attribute.py
@@ -20,6 +20,7 @@

import paddle
from paddle import _C_ops
from paddle.utils.decorator_utils import ParamAliasDecorator

from ..base.data_feeder import check_type, check_variable_and_dtype
from ..base.framework import in_dynamic_or_pir_mode, use_pir_api
@@ -144,11 +145,18 @@ def shape(input: Tensor) -> Tensor:
return out


@ParamAliasDecorator({"x": ["input"]})
def is_complex(x: Tensor) -> bool:
"""Return whether x is a tensor of complex data type(complex64 or complex128).


.. note::
Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
For example, ``input=tensor_x`` is equivalent to ``x=tensor_x``.

Args:
x (Tensor): The input tensor.
input: An alias for ``x``, with identical behavior.

Returns:
bool: True if the data type of the input is complex data type, otherwise false.
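A minimal sketch of the new ``input`` alias on ``paddle.is_complex`` (illustrative):

    import paddle

    x = paddle.randn([2, 3]) + 1j * paddle.randn([2, 3])
    # Equivalent calls; "input" is rewritten to "x" by the decorator.
    assert paddle.is_complex(x=x)
    assert paddle.is_complex(input=x)  # alias form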
7 changes: 7 additions & 0 deletions python/paddle/tensor/creation.py
@@ -24,6 +24,7 @@

import paddle
from paddle import _C_ops
from paddle.utils.decorator_utils import ParamAliasDecorator
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

from ..base.data_feeder import (
@@ -876,6 +877,7 @@ def _to_tensor_static(
return output


@ParamAliasDecorator({"place": ["device"]})
def to_tensor(
data: TensorLike | NestedNumericSequence,
dtype: DTypeLike | None = None,
@@ -889,6 +891,10 @@ def to_tensor(
If the ``data`` is already a Tensor, a copy is performed and a new tensor is returned.
If you only want to change the ``stop_gradient`` property, call ``Tensor.stop_gradient = stop_gradient`` directly.

.. note::
Alias Support: The parameter name ``device`` can be used as an alias for ``place``.
For example, ``device=paddle.CUDAPlace(0)`` is equivalent to ``place=paddle.CUDAPlace(0)``.

.. code-block:: text

We use the dtype conversion rules following this:
@@ -911,6 +917,7 @@
place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, which means the global place. If ``place`` is a
string, it can be ``cpu``, ``gpu:x`` or ``gpu_pinned``, where ``x`` is the index of the GPU.
device: An alias for ``place``, with identical behavior.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.

Returns:
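A minimal sketch of the new ``device`` alias on ``paddle.to_tensor`` (illustrative; passing both names raises a ValueError, as the new test below verifies):

    import paddle

    # Equivalent calls; "device" is rewritten to "place" by the decorator.
    t_place = paddle.to_tensor([1.0, 2.0], place=paddle.CPUPlace())
    t_device = paddle.to_tensor([1.0, 2.0], device=paddle.CPUPlace())  # alias form
    assert t_device.place == t_place.place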
32 changes: 8 additions & 24 deletions python/paddle/utils/decorator_utils.py
@@ -12,31 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import inspect
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
TypeVar,
cast,
)

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from collections.abc import Iterable

from collections.abc import Iterable
from typing import Any, Callable, TypeVar, cast

_P = ParamSpec("_P")
_R = TypeVar("_R")
_DecoratedFunc = Callable[_P, _R]
_F = TypeVar("_F", bound=Callable[..., Any])


class DecoratorBase(Generic[_P, _R]):
class DecoratorBase:
"""Decorative base class, providing a universal decorative framework.

Subclass only needs to implement the 'process' method to define the core logic.
@@ -47,19 +31,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs

def __call__(self, func: _DecoratedFunc[_P, _R]) -> _DecoratedFunc[_P, _R]:
def __call__(self, func: _F) -> _F:
"""As an entry point for decorative applications"""

@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
def wrapper(*args, **kwargs):
# Preprocess the parameters
processed_args, processed_kwargs = self.process(args, kwargs)
# Call the original function
return func(*processed_args, **processed_kwargs)

# Keep original signature
wrapper.__signature__ = inspect.signature(func)
return cast("_DecoratedFunc[_P, _R]", wrapper)
return cast("_F", wrapper)

def process(
self, args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -77,7 +61,7 @@ def process(


# Example implementation: Parameter alias decorator
class ParamAliasDecorator(DecoratorBase[_P, _R]):
class ParamAliasDecorator(DecoratorBase):
"""Implementation of Decorator for Parameter Alias Processing"""

def __init__(self, alias_mapping: dict[str, Iterable[str]]) -> None:
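The rest of ``ParamAliasDecorator`` is collapsed in this view. Below is a hypothetical sketch of the alias-resolution step, consistent with the ``ValueError`` message the new tests expect; it is a reconstruction, not the verbatim implementation:

    from collections.abc import Iterable
    from typing import Any

    def resolve_aliases(
        alias_mapping: dict[str, Iterable[str]], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        # Rewrite aliased keyword names to their canonical names (sketch).
        for canonical, aliases in alias_mapping.items():
            for alias in aliases:
                if alias in kwargs:
                    if canonical in kwargs:
                        raise ValueError(
                            f"Cannot specify both '{canonical}' "
                            f"and its alias '{alias}'"
                        )
                    kwargs[canonical] = kwargs.pop(alias)
        return kwargs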
46 changes: 46 additions & 0 deletions test/legacy_test/test_eager_tensor.py
@@ -377,6 +377,52 @@ def test_to_tensor_attributes(self):
self.assertEqual(var.dtype, paddle.float32)
self.assertEqual(var.type, core.VarDesc.VarType.DENSE_TENSOR)

def test_to_tensor_param_alias(self):
"""Test paddle.to_tensor parameter mapping ("place": ["device"])."""
# 1. Test equivalence of place and device parameters
tensor_place = paddle.to_tensor(self.array, place=paddle.CPUPlace())
tensor_device = paddle.to_tensor(self.array, device=paddle.CPUPlace())

np.testing.assert_array_equal(
tensor_device.numpy(), tensor_place.numpy()
)
self.assertEqual(tensor_device.place, tensor_place.place)

# 2. Test conflict between place and device (should raise ValueError)
with self.assertRaises(ValueError) as context:
paddle.to_tensor(
self.array,
place=paddle.CPUPlace(),
device=paddle.CPUPlace(), # Conflict
)
self.assertIn(
"Cannot specify both 'place' and its alias 'device'",
str(context.exception),
)

# 3. Test dtype and stop_gradient consistency
tensor1 = paddle.to_tensor(
self.array, dtype="float32", device=paddle.CPUPlace()
)
tensor2 = paddle.to_tensor(
self.array, dtype="float32", place=paddle.CPUPlace()
)

self.assertEqual(tensor1.dtype, tensor2.dtype)
self.assertEqual(tensor1.dtype, paddle.float32)
self.assertTrue(tensor1.stop_gradient)
self.assertEqual(tensor1.stop_gradient, tensor2.stop_gradient)

# 4. Test cross-device compatibility (CPU/GPU)
for device in [paddle.CPUPlace()] + (
[paddle.CUDAPlace(0)] if core.is_compiled_with_cuda() else []
):
tensor_device = paddle.to_tensor(self.array, device=device)
tensor_place = paddle.to_tensor(self.array, place=device)

self.assertEqual(tensor_device.place, tensor_place.place)
self.assertEqual(tensor_device.place, device)

def test_list_to_tensor(self):
array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]]
var = paddle.to_tensor(array, dtype="int32")
43 changes: 43 additions & 0 deletions test/legacy_test/test_iinfo_and_finfo.py
@@ -135,6 +135,49 @@ def test_finfo(self):
self.assertAlmostEqual(xinfo.resolution, 0.01)
self.assertAlmostEqual(xinfo.smallest_normal, 1.1754943508222875e-38)

def test_finfo_alias(self):
# dtype and type alias
for alias_param in ["dtype", "type"]:
for paddle_dtype, np_dtype in [
(paddle.float32, np.float32),
(paddle.float64, np.float64),
('float32', np.float32),
('float64', np.float64),
]:
xinfo = paddle.finfo(**{alias_param: paddle_dtype})
xninfo = np.finfo(np_dtype)
self.assertEqual(xinfo.dtype, xninfo.dtype)
self.assertEqual(xinfo.bits, xninfo.bits)
self.assertAlmostEqual(xinfo.max, xninfo.max)
self.assertAlmostEqual(xinfo.min, xninfo.min)
self.assertAlmostEqual(xinfo.eps, xninfo.eps)
self.assertAlmostEqual(xinfo.tiny, xninfo.tiny)
self.assertAlmostEqual(xinfo.resolution, xninfo.resolution)
if np.lib.NumpyVersion(np.__version__) >= "1.22.0":
self.assertAlmostEqual(
xinfo.smallest_normal, xninfo.smallest_normal
)

for paddle_dtype, np_dtype in [
(paddle.complex64, np.complex64),
(paddle.complex128, np.complex128),
('complex64', np.complex64),
('complex128', np.complex128),
]:
xinfo = paddle.finfo(**{alias_param: paddle_dtype})
xninfo = np.finfo(np_dtype)
self.assertEqual(xinfo.dtype, xninfo.dtype)
self.assertEqual(xinfo.bits, xninfo.bits)
self.assertAlmostEqual(xinfo.max, xninfo.max, places=16)
self.assertAlmostEqual(xinfo.min, xninfo.min, places=16)
self.assertAlmostEqual(xinfo.eps, xninfo.eps, places=16)
self.assertAlmostEqual(xinfo.tiny, xninfo.tiny, places=16)
self.assertAlmostEqual(xinfo.resolution, xninfo.resolution)
if np.lib.NumpyVersion(np.__version__) >= "1.22.0":
self.assertAlmostEqual(
xinfo.smallest_normal, xninfo.smallest_normal, places=16
)


if __name__ == '__main__':
unittest.main()
15 changes: 15 additions & 0 deletions test/legacy_test/test_is_complex.py
@@ -36,6 +36,21 @@ def test_for_exception(self):
with self.assertRaises(TypeError):
paddle.is_complex(np.array([1, 2]))

def test_for_alias(self):
for alias_param in ["x", "input"]:
# test_for_integer
x = paddle.arange(10)
self.assertFalse(paddle.is_complex(**{alias_param: x}))
# test_for_floating_point
x = paddle.randn([2, 3])
self.assertFalse(paddle.is_complex(**{alias_param: x}))
# test_for_complex
x = paddle.randn([2, 3]) + 1j * paddle.randn([2, 3])
self.assertTrue(paddle.is_complex(**{alias_param: x}))
# test_for_exception
with self.assertRaises(TypeError):
paddle.is_complex(**{alias_param: np.array([1, 2])})


if __name__ == '__main__':
unittest.main()
58 changes: 58 additions & 0 deletions test/legacy_test/test_pad_op.py
@@ -608,6 +608,64 @@ def init_case(self):
self.pad_value = 0.5


class TestPadAliasSupport(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shape = (2, 3)
self.paddings = [1, 2, 3, 4]
self.value = 0.5
self.x = np.random.random(self.shape).astype('float32')

def test_no_param_name(self):
out = paddle.nn.functional.pad(
paddle.to_tensor(self.x), self.paddings, value=self.value
)
expected = np.pad(
self.x,
[(1, 2), (3, 4)],
mode='constant',
constant_values=self.value,
)
np.testing.assert_array_equal(out.numpy(), expected)

def test_x_param_name(self):
out = paddle.nn.functional.pad(
x=paddle.to_tensor(self.x), pad=self.paddings, value=self.value
)
expected = np.pad(
self.x,
[(1, 2), (3, 4)],
mode='constant',
constant_values=self.value,
)
np.testing.assert_array_equal(out.numpy(), expected)

def test_input_param_name(self):
out = paddle.nn.functional.pad(
input=paddle.to_tensor(self.x), pad=self.paddings, value=self.value
)
expected = np.pad(
self.x,
[(1, 2), (3, 4)],
mode='constant',
constant_values=self.value,
)
np.testing.assert_array_equal(out.numpy(), expected)

def test_both_param_name(self):
with self.assertRaises(ValueError) as context:
paddle.nn.functional.pad(
x=paddle.to_tensor(self.x),
input=paddle.to_tensor(self.x),
pad=self.paddings,
value=self.value,
)
self.assertIn(
"Cannot specify both 'x' and its alias 'input'",
str(context.exception),
)


if __name__ == "__main__":
# paddle.enable_static()
unittest.main()