7 changes: 7 additions & 0 deletions python/paddle/tensor/manipulation.py
@@ -3312,6 +3312,7 @@ def squeeze(
return tensor_split(x, num_or_indices, axis=0, name=name)


@param_two_alias(["x", "input"], ["axis", "dim"])
def squeeze(
x: Tensor, axis: int | Sequence[int] | None = None, name: str | None = None
) -> Tensor:
@@ -3360,12 +3361,18 @@ def squeeze(
Output:
out.shape = [1, 3, 5]

.. note::
Alias Support: The parameter name ``input`` can be used as an alias for ``x``, and ``dim`` can be used as an alias for ``axis``.
For example, ``squeeze(input=tensor_x, dim=1)`` is equivalent to ``squeeze(x=tensor_x, axis=1)``.

Args:
x (Tensor): The input Tensor. Supported data type: float32, float64, bool, int8, int32, int64.
alias: ``input``.
axis (int|list|tuple, optional): An integer or list/tuple of integers, indicating the dimensions to be squeezed. Default is None.
The range of axis is :math:`[-ndim(x), ndim(x))`.
If axis is negative, :math:`axis = axis + ndim(x)`.
If axis is None, all the dimensions of x of size 1 will be removed.
alias: ``dim``.
name (str|None, optional): Please refer to :ref:`api_guide_Name`. Default: None.

Returns:
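A quick sketch of the alias behavior documented above, with hypothetical tensor values (the two calls should be interchangeable once ``param_two_alias`` is applied):

import paddle

t = paddle.rand([5, 1, 10])
a = paddle.squeeze(x=t, axis=1)     # canonical parameter names
b = paddle.squeeze(input=t, dim=1)  # aliases added by this PR
assert a.shape == b.shape == [5, 10]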
31 changes: 28 additions & 3 deletions python/paddle/tensor/math.py
@@ -30,7 +30,11 @@
from paddle.base.libpaddle import DataType
from paddle.common_ops_import import VarDesc, dygraph_utils
from paddle.pir import Value
from paddle.utils.decorator_utils import ParamAliasDecorator, param_two_alias
from paddle.utils.decorator_utils import (
ParamAliasDecorator,
param_one_alias,
param_two_alias,
)
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

from ..base.data_feeder import (
@@ -746,11 +750,17 @@ def add(
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0

.. note::
Alias Support: The parameter name ``input`` can be used as an alias for ``x``, and ``other`` can be used as an alias for ``y``.
For example, ``add(input=tensor_x, other=tensor_y)`` is equivalent to ``add(x=tensor_x, y=tensor_y)``.

Args:
x (Tensor): Tensor of any dimensions. Its dtype should be bool, bfloat16, float16, float32, float64,
int8, int16, int32, int64, uint8, complex64, complex128.
alias: ``input``.
y (Tensor): Tensor of any dimensions. Its dtype should be bool, bfloat16, float16, float32, float64,
int8, int16, int32, int64, uint8, complex64, complex128.
alias: ``other``.
alpha (Number, optional): Scaling factor for Y. Default: 1.
out (Tensor, optional): The output tensor. Default: None.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
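A minimal sketch of the equivalence described in the note above, plus the ``alpha`` scaling factor from the Args list (both per this diff, so treat the exact semantics as assumptions):

import paddle

x = paddle.to_tensor([1.0, 2.0])
y = paddle.to_tensor([10.0, 20.0])
out1 = paddle.add(x=x, y=y)          # canonical names
out2 = paddle.add(input=x, other=y)  # PyTorch-style aliases
out3 = paddle.add(x, y, alpha=2)     # x + alpha * y -> [21.0, 42.0]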
@@ -985,11 +995,17 @@ def divide(

.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor

.. note::
Alias Support: The parameter name ``input`` can be used as an alias for ``x``, and ``other`` can be used as an alias for ``y``.
For example, ``divide(input=tensor_x, other=tensor_y)`` is equivalent to ``divide(x=tensor_x, y=tensor_y)``.

Args:
x (Tensor): the input tensor, its data type should be bool, bfloat16, float16, float32, float64,
int8, int16, int32, int64, uint8, complex64, complex128.
alias: ``input``.
y (Tensor): the input tensor, its data type should be bool, bfloat16, float16, float32, float64,
int8, int16, int32, int64, uint8, complex64, complex128.
alias: ``other``.
rounding_mode (str|None, optional): The rounding mode. Can be None (default), "trunc" (truncate toward zero), or "floor" (round down toward negative infinity).
out (Tensor, optional): The output tensor. Default: None.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
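A corresponding sketch for ``divide``; the ``rounding_mode`` values follow the Args entry above (semantics assumed to match the usual trunc/floor conventions):

import paddle

x = paddle.to_tensor([7.0, -7.0])
y = paddle.to_tensor([2.0, 2.0])
paddle.divide(input=x, other=y)             # true division -> [3.5, -3.5]
paddle.divide(x, y, rounding_mode="trunc")  # toward zero   -> [3.0, -3.0]
paddle.divide(x, y, rounding_mode="floor")  # toward -inf   -> [3.0, -4.0]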
@@ -3801,11 +3817,14 @@ def log10_(x: Tensor, name: str | None = None) -> Tensor:
return _C_ops.log10_(x)


@param_one_alias(["x", "input"])
def clip(
x: Tensor,
min: float | None = None,
max: float | None = None,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor:
"""
This operator clips all elements in input into the range [min, max] and return
@@ -3815,13 +3834,19 @@ def clip(

Out = MIN(MAX(x, min), max)

.. note::
Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
For example, ``clip(input=tensor_x)`` is equivalent to ``clip(x=tensor_x)``.

Args:
x (Tensor): An N-D Tensor with data type bfloat16, float16, float32, float64, int32 or int64.
alias: ``input``.
min (float|int|Tensor, optional): The lower bound with type ``float``, ``int`` or a ``0-D Tensor``
with shape [] and type ``bfloat16``, ``float16``, ``float32``, ``float64``, ``int32``.
max (float|int|Tensor, optional): The upper bound with type ``float``, ``int`` or a ``0-D Tensor``
with shape [] and type ``bfloat16``, ``float16``, ``float32``, ``float64``, ``int32``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output tensor. Default: None.

Returns:
Tensor: A Tensor with the same shape as the input. If either min or max is a floating-point value/Tensor, the output tensor will have a data type of ``float32``. Otherwise, the output tensor will inherit the same data type as the input.
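A short sketch of the new keyword-only ``out`` argument that the hunks below thread through to ``_C_ops.clip`` (assuming, as the tests later verify, that the result is both written into ``out`` and returned):

import paddle

x = paddle.to_tensor([0.1, 0.5, 0.9])
buf = paddle.empty([3])
res = paddle.clip(x, min=0.3, max=0.7, out=buf)
# buf and res should both hold [0.3, 0.5, 0.7]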
@@ -3872,7 +3897,7 @@ def clip(
max = max.item(0)
if isinstance(min, float) or isinstance(max, float):
x = paddle.cast(x, paddle.float32)
return _C_ops.clip(x, min, max)
return _C_ops.clip(x, min, max, out=out)
elif in_pir_mode():
if x_dtype in ['paddle.int32', 'paddle.int64']:
if (
@@ -3888,7 +3913,7 @@ def clip(
)
):
x = paddle.cast(x, paddle.float32)
return _C_ops.clip(x, min, max)
return _C_ops.clip(x, min, max, out=out)
else:
if min is not None:
check_type(min, 'min', (float, int, Variable), 'clip')
158 changes: 158 additions & 0 deletions test/legacy_test/test_clip_op.py
@@ -16,6 +16,7 @@

import numpy as np
from op_test import OpTest, convert_float_to_uint16
from utils import dygraph_guard, static_guard

import paddle
from paddle import base
@@ -698,5 +699,162 @@ def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_pir=True)


class TestClipCompatibility(unittest.TestCase):
def setUp(self):
self.places = [paddle.CPUPlace()]
if paddle.base.core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.func = paddle.clip
self.init_data()
self.init_case()

def init_data(self):
self.shape = [5, 6]
self.dtype = 'float32'
self.min_val = 0.3
self.max_val = 0.7
self.np_input = np.random.rand(*self.shape).astype(self.dtype)
self.np_out = np.clip(self.np_input, self.min_val, self.max_val)

def init_case(self):
params = [['x', 'input'], ['min'], ['max']]

# Generate every valid calling convention: each parameter is either passed
# positionally (None) or by one of its accepted names, with all positional
# arguments placed before the first keyword argument
def generate_cases(param_groups, case_list):
from itertools import product

for combo in product(*[[None, *names] for names in param_groups]):
args = ['pos' if p is None else 'kw' for p in combo]
if args == sorted(args, key=lambda x: x != 'pos'):
case_list.append(combo)

# paddle.clip()
self.test_cases = []
generate_cases(params, self.test_cases)
# x.clip()
self.tensor_test_cases = []
generate_cases(params[1:], self.tensor_test_cases)

def _build_args_kwargs(self, param_names, params):
args = []
kwargs = {}
for name, param in zip(param_names, params):
if name is None:
args.append(param)
else:
kwargs[name] = param
return args, kwargs

def test_dygraph_compatibility(self):
with dygraph_guard():
for place in self.places:
paddle.device.set_device(place)
x = paddle.to_tensor(self.np_input)
# paddle.
for param_names in self.test_cases:
args, kwargs = self._build_args_kwargs(
param_names, (x, self.min_val, self.max_val)
)
for out_flag in [False, True]:
if out_flag:
kwargs['out'] = paddle.empty([])
self.func(*args, **kwargs)
out = kwargs["out"]
else:
out = self.func(*args, **kwargs)
np.testing.assert_array_equal(self.np_out, out.numpy())
# paddle.Tensor.
for param_names in self.tensor_test_cases:
args, kwargs = self._build_args_kwargs(
param_names, (self.min_val, self.max_val)
)
out = x.clip(*args, **kwargs)
np.testing.assert_array_equal(self.np_out, out.numpy())

def test_dygraph_out(self):
def run_clip(test_type):
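# Exercise three calling styles for the new `out` argument:
#   "return"   -> out = clip(...)           plain return value
#   "with_out" -> clip(..., out=buf)        result written into a preallocated tensor
#   "both"     -> out = clip(..., out=buf)  written in place and returned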
x = paddle.to_tensor(self.np_input)
x.stop_gradient = False
out = (
paddle.zeros(self.np_out.shape)
if test_type in ["with_out", "both"]
else None
)
if test_type == "return":
out = paddle.clip(x, self.min_val, self.max_val)
elif test_type == "with_out":
paddle.clip(x, self.min_val, self.max_val, out=out)
elif test_type == "both":
out = paddle.clip(x, self.min_val, self.max_val, out=out)
else:
raise ValueError(f"Invalid test_mode: {test_type}")

expected = paddle._C_ops.clip(x, self.min_val, self.max_val)
np.testing.assert_array_equal(out.numpy(), expected.numpy())
loss = out.sum().astype('float32')
loss.backward()
return out, x.grad

def assert_outputs_equal(outputs, rtol: float = 1e-10):
for out in outputs[1:]:
np.testing.assert_allclose(
outputs[0].numpy(), out.numpy(), rtol=rtol
)

with dygraph_guard():
for place in self.places:
paddle.device.set_device(place)
out1, grad1 = run_clip("return")
out2, grad2 = run_clip("with_out")
out3, grad3 = run_clip("both")

assert_outputs_equal([out1, out2, out3])
if (
grad1 is not None
and grad2 is not None
and grad3 is not None
):
assert_outputs_equal([grad1, grad2, grad3])

def test_static_compatibility(self):
with static_guard():
for place in self.places:
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.base.program_guard(main, startup):
x = paddle.static.data(
name="x", shape=self.shape, dtype=self.dtype
)
# paddle.
for param_names in self.test_cases:
args, kwargs = self._build_args_kwargs(
param_names, (x, self.min_val, self.max_val)
)
out = self.func(*args, **kwargs)

exe = paddle.base.Executor(place)
fetches = exe.run(
main,
feed={"x": self.np_input},
fetch_list=[out],
)
np.testing.assert_array_equal(self.np_out, fetches[0])
# paddle.Tensor.
for param_names in self.tensor_test_cases:
args, kwargs = self._build_args_kwargs(
param_names, (self.min_val, self.max_val)
)

out = x.clip(*args, **kwargs)

exe = paddle.base.Executor(place)
fetches = exe.run(
main,
feed={"x": self.np_input},
fetch_list=[out],
)
np.testing.assert_array_equal(self.np_out, fetches[0])


if __name__ == '__main__':
unittest.main()