diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index cd55d1372a42d3..3b329fac4fdd60 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -3714,10 +3714,33 @@ def log10_(x: Tensor, name: str | None = None) -> Tensor:
     return _C_ops.log10_(x)


+def check_clip_tensor(c_x, value, re_value, value_type, name):
+    if value is None:
+        value = paddle.full_like(c_x, re_value, value_type)
+    else:
+        if isinstance(value, (Variable, paddle.pir.Value, paddle.Tensor)):
+            if len(value.shape) == 1 and value.shape[-1] == 0:
+                raise ValueError(
+                    f"The {name} dimension should be equal to the inner dimension of x, but the {name} dimension is {value.shape}."
+                )
+            elif (
+                len(value.shape) != 0
+                and value.shape != c_x.shape[-len(value.shape) :]
+                and value.shape != [1]
+                and value.shape != (1,)
+            ):
+                raise ValueError(
+                    f"The {name} dimension should be equal to the inner dimension of x, but the {name} dimension is {value.shape} and the x dimension is {c_x.shape[-len(value.shape):]}."
+                )
+        else:
+            value = paddle.full_like(c_x, value, value_type)
+    return value
+
+
 def clip(
     x: Tensor,
-    min: float | None = None,
-    max: float | None = None,
+    min: float | Tensor | None = None,
+    max: float | Tensor | None = None,
     name: str | None = None,
 ) -> Tensor:
     """
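Not part of the patch: a minimal sketch of the shape contract `check_clip_tensor` enforces, assuming the patch is applied. A tensor bound must be non-empty and either 0-D, of shape `[1]`/`(1,)`, or equal to the trailing dimensions of `x`; anything else raises `ValueError`.

    import paddle

    x = paddle.rand([2, 3, 4])

    # OK: the bound matches x.shape[-2:], so it is validated and later
    # broadcast to x.shape.
    out = paddle.clip(x, min=paddle.zeros([3, 4]))

    # ValueError: shape [2] is neither [1] nor a suffix of x.shape
    # (x.shape[-1:] is [4]).
    # paddle.clip(x, min=paddle.zeros([2]))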
@@ -3761,84 +3784,119 @@ def clip(
     if x_dtype == 'paddle.int32':
         min_ = np.iinfo(np.int32).min
         max_ = np.iinfo(np.int32).max - 2**7
+        value_dtype = 'int32'
     elif x_dtype == 'paddle.int64':
         min_ = np.iinfo(np.int64).min
         max_ = np.iinfo(np.int64).max - 2**39
+        value_dtype = 'int64'
     elif x_dtype == 'paddle.float16':
         min_ = float(np.finfo(np.float16).min)
         max_ = float(np.finfo(np.float16).max)
+        value_dtype = 'float16'
     else:
         min_ = float(np.finfo(np.float32).min)
         max_ = float(np.finfo(np.float32).max)
+        value_dtype = 'float32'
+
+    if (
+        isinstance(min, (Variable, paddle.pir.Value, paddle.Tensor))
+        and (
+            (len(min.shape) == 1 and min.shape[-1] not in [1, 0])
+            or len(min.shape) > 1
+        )
+    ) or (
+        isinstance(max, (Variable, paddle.pir.Value, paddle.Tensor))
+        and (
+            (len(max.shape) == 1 and max.shape[-1] not in [1, 0])
+            or len(max.shape) > 1
+        )
+    ):
+        min_n = check_clip_tensor(x, min, min_, value_dtype, 'min')
+        max_n = check_clip_tensor(x, max, max_, value_dtype, 'max')
+
+        min_n = (
+            paddle.broadcast_to(min_n, x.shape)
+            if min_n.shape != x.shape
+            else min_n
+        )
+        max_n = (
+            paddle.broadcast_to(max_n, x.shape)
+            if max_n.shape != x.shape
+            else max_n
+        )
+
+        output_min = paddle.where(x < min_n, min_n, x)
+        output = paddle.where(output_min > max_n, max_n, output_min)
+        return output

-    if in_dynamic_or_pir_mode():
-        if isinstance(min, Variable):
-            min = min.item(0)
-        if isinstance(max, Variable):
-            max = max.item(0)
-        min = min_ if min is None else min
-        max = max_ if max is None else max
-        return _C_ops.clip(x, min, max)
     else:
-        if min is not None:
-            check_type(min, 'min', (float, int, Variable), 'clip')
+        if in_dynamic_or_pir_mode():
             if isinstance(min, Variable):
-                check_dtype(
-                    min.dtype,
-                    'min',
-                    ['float16', 'float32', 'float64', 'int32', 'uint16'],
-                    'clip',
-                    '(When the type of min in clip is Variable.)',
-                )
-        if max is not None:
-            check_type(max, 'max', (float, int, Variable), 'clip')
+                min = min.item(0)
             if isinstance(max, Variable):
-                check_dtype(
-                    max.dtype,
-                    'max',
-                    ['float16', 'float32', 'float64', 'int32', 'uint16'],
-                    'clip',
-                    '(When the type of max in clip is Variable.)',
-                )
-
-        check_variable_and_dtype(
-            x,
-            'x',
-            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
-            'clip',
-        )
+                max = max.item(0)
+            min = min_ if min is None else min
+            max = max_ if max is None else max
+            return _C_ops.clip(x, min, max)
+        else:
+            if min is not None:
+                check_type(min, 'min', (float, int, Variable), 'clip')
+                if isinstance(min, Variable):
+                    check_dtype(
+                        min.dtype,
+                        'min',
+                        ['float16', 'float32', 'float64', 'int32', 'uint16'],
+                        'clip',
+                        '(When the type of min in clip is Variable.)',
+                    )
+            if max is not None:
+                check_type(max, 'max', (float, int, Variable), 'clip')
+                if isinstance(max, Variable):
+                    check_dtype(
+                        max.dtype,
+                        'max',
+                        ['float16', 'float32', 'float64', 'int32', 'uint16'],
+                        'clip',
+                        '(When the type of max in clip is Variable.)',
+                    )

-        inputs = {'X': x}
-        attrs = {'min': min_, 'max': max_}
+            check_variable_and_dtype(
+                x,
+                'x',
+                ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
+                'clip',
+            )

-        if isinstance(min, Variable):
-            min.stop_gradient = True
-            inputs['Min'] = min
-        elif min is not None:
-            attrs['min'] = min
+            inputs = {'X': x}
+            attrs = {'min': min_, 'max': max_}

-        if isinstance(max, Variable):
-            max.stop_gradient = True
-            inputs['Max'] = max
-        elif max is not None:
-            attrs['max'] = max
+            if isinstance(min, Variable):
+                min.stop_gradient = True
+                inputs['Min'] = min
+            elif min is not None:
+                attrs['min'] = min

-        helper = LayerHelper('clip', **locals())
-        output = helper.create_variable_for_type_inference(
-            dtype=helper.input_dtype('x')
-        )
-        helper.append_op(
-            type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs
-        )
+            if isinstance(max, Variable):
+                max.stop_gradient = True
+                inputs['Max'] = max
+            elif max is not None:
+                attrs['max'] = max
+
+            helper = LayerHelper('clip', **locals())
+            output = helper.create_variable_for_type_inference(
+                dtype=helper.input_dtype('x')
+            )
+            helper.append_op(
+                type='clip',
+                inputs=inputs,
+                outputs={'Out': [output]},
+                attrs=attrs,
+            )

-        return output
+            return output


 @inplace_apis_in_dygraph_only
 def clip_(
     x: Tensor,
-    min: float | None = None,
-    max: float | None = None,
+    min: float | Tensor | None = None,
+    max: float | Tensor | None = None,
     name: str | None = None,
 ) -> Tensor:
     """
@@ -3847,15 +3905,37 @@ def clip_(
     """
     fmin = float(np.finfo(np.float32).min)
     fmax = float(np.finfo(np.float32).max)
-    if isinstance(min, Variable):
-        min = min.item(0)
-    if isinstance(max, Variable):
-        max = max.item(0)
     min = fmin if min is None else min
     max = fmax if max is None else max

     if in_dynamic_mode():
-        return _C_ops.clip_(x, min, max)
+        if (
+            isinstance(min, (Variable, paddle.pir.Value, paddle.Tensor))
+            and (
+                (len(min.shape) == 1 and min.shape[-1] not in [1, 0])
+                or len(min.shape) > 1
+            )
+        ) or (
+            isinstance(max, (Variable, paddle.pir.Value, paddle.Tensor))
+            and (
+                (len(max.shape) == 1 and max.shape[-1] not in [1, 0])
+                or len(max.shape) > 1
+            )
+        ):
+            max = check_clip_tensor(x, max, fmax, x.dtype, 'max')
+            min = check_clip_tensor(x, min, fmin, x.dtype, 'min')
+
+            max_expand = (
+                paddle.broadcast_to(max, x.shape) if max.shape != x.shape else max
+            )
+            min_expand = (
+                paddle.broadcast_to(min, x.shape) if min.shape != x.shape else min
+            )
+
+            paddle.where_(x > min_expand, x, min_expand)
+            return paddle.where_(x < max_expand, x, max_expand)
+        else:
+            return _C_ops.clip_(x, min, max)


 def trace(
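For reference, a hedged usage sketch of the new tensor-bound path in `paddle.clip` (names `x`, `lo`, `hi` are illustrative; it mirrors the first case of `test_dygraph_clip` in the new test file below). Bounds whose shapes match the trailing dimensions of `x` are broadcast to `x.shape`, and the result agrees with `np.clip`:

    import numpy as np
    import paddle

    x = paddle.rand([1, 2, 3, 4])
    lo = paddle.rand([3, 4])      # broadcast over the leading dims of x
    hi = paddle.rand([2, 3, 4])

    # Elementwise: raise x to at least lo, then cap at hi.
    out = paddle.clip(x, lo, hi)
    np.testing.assert_allclose(
        out.numpy(), np.clip(x.numpy(), lo.numpy(), hi.numpy()), rtol=1e-05
    )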
diff --git a/test/legacy_test/test_clip_tensor_op.py b/test/legacy_test/test_clip_tensor_op.py
new file mode 100644
index 00000000000000..d7919c6bfe908a
--- /dev/null
+++ b/test/legacy_test/test_clip_tensor_op.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import base
+
+
+class TestClipTensor(unittest.TestCase):
+    # def test_static_clip(self):
+    #     data_shape = [1, 2, 3, 4]
+    #     self.place = (
+    #         base.CUDAPlace(0)
+    #         if base.core.is_compiled_with_cuda()
+    #         else base.CPUPlace()
+    #     )
+    #     data = np.random.random(data_shape).astype('float32')
+    #     min_data = np.random.random(data_shape[-2:]).astype('float32')
+    #     max_data = np.random.random(data_shape[-3:]).astype('float32')
+    #     paddle.enable_static()
+    #     with paddle.static.program_guard(paddle.static.Program()):
+    #         x = paddle.static.data(name='x', shape=data_shape, dtype='float32')
+    #         min = paddle.static.data(
+    #             name='min', shape=data_shape[-2:], dtype='float32'
+    #         )
+    #         max = paddle.static.data(
+    #             name='max', shape=data_shape[-3:], dtype='float32'
+    #         )
+    #         out = paddle.clip(x, min, max)
+    #         exe = base.Executor(self.place)
+    #         res = exe.run(
+    #             feed={
+    #                 "x": data,
+    #                 'min': min_data,
+    #                 'max': max_data,
+    #             },
+    #             fetch_list=[out],
+    #         )
+    #         res_np = np.clip(data, min_data, max_data)
+    #         np.testing.assert_allclose(res_np, res[0], rtol=1e-05)
+    #     paddle.disable_static()
+
+    #     data_shape = [1, 2, 3, 4]
+    #     self.place = (
+    #         base.CUDAPlace(0)
+    #         if base.core.is_compiled_with_cuda()
+    #         else base.CPUPlace()
+    #     )
+    #     data = np.random.random(data_shape).astype('float32')
+    #     min_data = np.random.random(data_shape[-2:]).astype('float32')
+    #     max_data = np.random.random(data_shape[-3:]).astype('float32')
+    #     paddle.enable_static()
+    #     with paddle.static.program_guard(paddle.static.Program()):
+    #         x = paddle.static.data(name='x', shape=data_shape, dtype='float32')
+    #         min = paddle.static.data(
+    #             name='min', shape=data_shape[-2:], dtype='float32'
+    #         )
+    #         max = 5.0
+    #         out = paddle.clip(x, min, max)
+    #         exe = base.Executor(self.place)
+    #         res = exe.run(
+    #             feed={
+    #                 "x": data,
+    #                 'min': min_data,
+    #                 'max': 5.0,
+    #             },
+    #             fetch_list=[out],
+    #         )
+    #         res_np = np.clip(data, min_data, 5.0)
+    #         np.testing.assert_allclose(res_np, res[0], rtol=1e-05)
+    #     paddle.disable_static()
+
+    #     data_shape = [1, 2, 3, 4]
+    #     self.place = (
+    #         base.CUDAPlace(0)
+    #         if base.core.is_compiled_with_cuda()
+    #         else base.CPUPlace()
+    #     )
+    #     data = np.random.random(data_shape).astype('float32')
+    #     min_data = np.random.random(data_shape[-2:]).astype('float32')
+    #     max_data = float(np.finfo(np.float32).max)
+    #     paddle.enable_static()
+    #     with paddle.static.program_guard(paddle.static.Program()):
+    #         x = paddle.static.data(name='x', shape=data_shape, dtype='float32')
+    #         min = paddle.static.data(
+    #             name='min', shape=data_shape[-2:], dtype='float32'
+    #         )
+    #         out = paddle.clip(x, min)
+    #         exe = base.Executor(self.place)
+    #         res = exe.run(
+    #             feed={"x": data, 'min': min_data},
+    #             fetch_list=[out],
+    #         )
+    #         res_np = np.clip(data, min_data, max_data)
+    #         np.testing.assert_allclose(res_np, res[0], rtol=1e-05)
+    #     paddle.disable_static()
+
+    def test_dygraph_clip(self):
+        self.place = (
+            base.CUDAPlace(0)
+            if base.core.is_compiled_with_cuda()
+            else base.CPUPlace()
+        )
+        paddle.disable_static(self.place)
+        data_shape = [1, 2, 3, 4]
+        data = np.random.random(data_shape).astype('float32')
+        min_data = np.random.random(data_shape[-2:]).astype('float32')
+        max_data = np.random.random(data_shape[-3:]).astype('float32')
+        out_np = np.clip(data, min_data, max_data)
+        data = paddle.to_tensor(data)
+        min_data = paddle.to_tensor(min_data)
+        max_data = paddle.to_tensor(max_data)
+        out = paddle.clip(data, min_data, max_data)
+        np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05)
+
+        data_shape = [1, 2, 3, 4]
+        data = np.random.random(data_shape).astype('float32')
+        min_data = np.random.random(data_shape[-2:]).astype('float32')
+        max_data = np.random.random(data_shape[-1:]).astype('float32')
+        out_np = np.clip(data, min_data, max_data)
+        data = paddle.to_tensor(data)
+        min_data = paddle.to_tensor(min_data)
+        max_data = paddle.to_tensor(max_data)
+        out = paddle.clip(data, min_data, max_data)
+        np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05)
+
+        data_shape = [1, 2, 3, 4]
+        data = np.random.randint(0, 10, data_shape).astype('int32')
+        min_data = np.random.randint(0, 10, data_shape[-2:]).astype('int32')
+        max_data = 5
+        out_np = np.clip(data, min_data, max_data)
+        data = paddle.to_tensor(data)
+        min_data = paddle.to_tensor(min_data)
+        max_data = paddle.to_tensor([max_data], dtype='int32')
+        out = paddle.clip(data, min_data, max_data)
+        np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05)
+
+        data_shape = [1, 2, 3, 4]
+        data = np.random.random(data_shape).astype('float32')
+        min_data = np.random.random(data_shape[-2:]).astype('float32')
+        max_data = float(np.finfo(np.float32).max)
+        out_np = np.clip(data, min_data, max_data)
+        data = paddle.to_tensor(data)
+        min_data = paddle.to_tensor(min_data)
+        out = paddle.clip(data, min_data)
+        np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05)
+
+        data_shape = [1, 2, 3, 4]
+        data = np.random.random(data_shape).astype('float32')
+        min_data = np.random.random(data_shape[-2:]).astype('float32')
+        max_data = 5
+        out_np = np.clip(data, min_data, max_data)
+        data = paddle.to_tensor(data)
+        min_data = paddle.to_tensor(min_data)
+        out = paddle.clip(data, min_data, max_data)
+        np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05)
+
+        paddle.enable_static()
+
+    def test_shapeerror_clip(self):
+        data_shape = [1, 9, 9, 4]
+        data = np.random.random(data_shape).astype('float32')
+        data = paddle.to_tensor(data)
+        with self.assertRaises(ValueError):
+            paddle.clip(data, paddle.rand([2]))
+
+        data_shape = [1, 9, 9, 4]
+        data = np.random.random(data_shape).astype('float32')
+        data = paddle.to_tensor(data)
+        with self.assertRaises(ValueError):
+            paddle.clip(
+                data, min=paddle.to_tensor([1, 2, 3, 4]), max=paddle.rand([0])
+            )
+
+    def test_tensor_clip_(self):
+        data_shape = [1, 9, 9, 4]
+        data_np = np.random.random(data_shape).astype('float32')
+        min_np = np.random.random(data_shape[-2:]).astype('float32')
+        data = paddle.to_tensor(data_np)
+        min = paddle.to_tensor(min_np)
+        max = min + 5
+        data.clip_(min, max)
+        out_np = np.clip(data_np, min_np, min_np + 5)
+        np.testing.assert_allclose(data.numpy(), out_np, rtol=1e-05)
+
+
+if __name__ == '__main__':
+    unittest.main()
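And a matching hedged sketch for the in-place variant, mirroring `test_tensor_clip_` above (assumes the patch is applied; `x_np` and `lo_np` are illustrative names):

    import numpy as np
    import paddle

    x_np = np.random.random([1, 9, 9, 4]).astype('float32')
    lo_np = np.random.random([9, 4]).astype('float32')

    x = paddle.to_tensor(x_np)
    lo = paddle.to_tensor(lo_np)
    # Clip x in place against tensor bounds broadcast from the trailing dims.
    x.clip_(lo, lo + 5)
    np.testing.assert_allclose(
        x.numpy(), np.clip(x_np, lo_np, lo_np + 5), rtol=1e-05
    )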