diff --git a/paddle/phi/kernels/clip_grad_kernel.h b/paddle/phi/kernels/clip_grad_kernel.h
index bc6245ce90eabe..d09ef0081c7cf8 100644
--- a/paddle/phi/kernels/clip_grad_kernel.h
+++ b/paddle/phi/kernels/clip_grad_kernel.h
@@ -24,8 +24,8 @@ template <typename T, typename Context>
 void ClipGradKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& out_grad,
-                    const Scalar& min,
-                    const Scalar& max,
+                    const DenseTensor& min,
+                    const DenseTensor& max,
                     DenseTensor* x_grad);

 }  // namespace phi
diff --git a/paddle/phi/kernels/clip_kernel.h b/paddle/phi/kernels/clip_kernel.h
index 14ac8342e03bcf..816df32c93ca9d 100644
--- a/paddle/phi/kernels/clip_kernel.h
+++ b/paddle/phi/kernels/clip_kernel.h
@@ -14,18 +14,16 @@

 #pragma once

-#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/device_context.h"
-#include "paddle/phi/core/selected_rows.h"

 namespace phi {

 template <typename T, typename Context>
 void ClipKernel(const Context& dev_ctx,
                 const DenseTensor& x,
-                const Scalar& min,
-                const Scalar& max,
+                const DenseTensor& min,
+                const DenseTensor& max,
                 DenseTensor* out);

 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/clip_grad_kernel.cc b/paddle/phi/kernels/cpu/clip_grad_kernel.cc
index 89a14af10d16c5..11714b5d6691a3 100644
--- a/paddle/phi/kernels/cpu/clip_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/clip_grad_kernel.cc
@@ -16,7 +16,29 @@

 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ClipGradKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    const DenseTensor& min,
+                    const DenseTensor& max,
+                    DenseTensor* x_grad) {
+  int64_t numel = out_grad.numel();
+  auto* d_x_data = ctx.template Alloc<T>(x_grad);
+  const auto* x_data = x.data<T>();
+  const auto* min_data = min.data<T>();
+  const auto* max_data = max.data<T>();
+  auto* d_out_data = out_grad.data<T>();
+
+  for (int64_t i = 0; i < numel; i++) {
+    d_x_data[i] = (x_data[i] > min_data[i] && x_data[i] < max_data[i])
+                      ? d_out_data[i]
+                      : static_cast<T>(0);
+  }
+}
+
+}  // namespace phi

 PD_REGISTER_KERNEL(clip_grad,
                    CPU,
diff --git a/paddle/phi/kernels/cpu/clip_kernel.cc b/paddle/phi/kernels/cpu/clip_kernel.cc
index bcbb85279277e5..ee14ca08a472f4 100644
--- a/paddle/phi/kernels/cpu/clip_kernel.cc
+++ b/paddle/phi/kernels/cpu/clip_kernel.cc
@@ -16,7 +16,36 @@

 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ClipKernel(const Context& ctx,
+                const DenseTensor& x,
+                const DenseTensor& min,
+                const DenseTensor& max,
+                DenseTensor* out) {
+  const T* x_data = x.data<T>();
+  const T* min_data = min.data<T>();
+  const T* max_data = max.data<T>();
+  auto x_numel = x.numel();
+
+  T* out_data = ctx.template Alloc<T>(out);
+
+  for (int64_t i = 0; i < x_numel; i++) {
+    PADDLE_ENFORCE_LE(
+        min_data[i],
+        max_data[i],
+        errors::InvalidArgument("max should be greater than or equal to min. "
+                                "But received min = %f, max = %f",
+                                static_cast<float>(min_data[i]),
+                                static_cast<float>(max_data[i])));
+
+    out_data[i] = x_data[i] < min_data[i]
+                      ? min_data[i]
+                      : x_data[i] > max_data[i] ? max_data[i] : x_data[i];
+  }
+}
+
+}  // namespace phi

 PD_REGISTER_KERNEL(
     clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}
diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
index 60d311a2555a0d..c6c6475ead4058 100644
--- a/paddle/phi/kernels/gpu/clip_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
@@ -14,10 +14,36 @@

 #include "paddle/phi/kernels/clip_grad_kernel.h"

+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"
+
+namespace phi {
+
+template <typename T>
+struct ClipGradGPUFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y, const T min_, const T max_) const {
+    return (y > min_ && y < max_) ? x : static_cast<T>(0);
+  }
+};
+
+template <typename T, typename Context>
+void ClipGradKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    const DenseTensor& min,
+                    const DenseTensor& max,
+                    DenseTensor* x_grad) {
+  std::vector<const DenseTensor*> ins = {&out_grad, &x, &min, &max};
+  std::vector<DenseTensor*> outs = {x_grad};
+  auto functor = ClipGradGPUFunctor<T>();
+  ctx.template Alloc<T>(x_grad);
+  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
+}
+
+}  // namespace phi

 PD_REGISTER_KERNEL(clip_grad,
                    GPU,
diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu
index e8d519a5d3a2b9..3a7c24abd351bc 100644
--- a/paddle/phi/kernels/gpu/clip_kernel.cu
+++ b/paddle/phi/kernels/gpu/clip_kernel.cu
@@ -17,7 +17,34 @@

 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+#include "paddle/phi/kernels/funcs/elementwise_functor.h"
+
+namespace phi {
+
+template <typename T>
+struct ClipGPUFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const {
+    return x < min_ ? min_ : x > max_ ? max_ : x;
+  }
+};
+
+template <typename T, typename Context>
+void ClipKernel(const Context& ctx,
+                const DenseTensor& x,
+                const DenseTensor& min,
+                const DenseTensor& max,
+                DenseTensor* out) {
+  std::vector<const DenseTensor*> ins = {&x, &min, &max};
+  std::vector<DenseTensor*> outs = {out};
+  ctx.template Alloc<T>(out);
+
+  ClipGPUFunctor<T> func;
+  funcs::ElementwiseKernel<T, ClipGPUFunctor<T>, 1>(ctx, ins, &outs, func);
+}
+
+}  // namespace phi

 PD_REGISTER_KERNEL(clip,
                    GPU,
diff --git a/paddle/phi/kernels/impl/clip_kernel_impl.h b/paddle/phi/kernels/impl/clip_kernel_impl.h
index 7d327ef5c5dfaf..3a5c415b81c7ea 100644
--- a/paddle/phi/kernels/impl/clip_kernel_impl.h
+++ b/paddle/phi/kernels/impl/clip_kernel_impl.h
@@ -43,8 +43,8 @@ void ClipKernel(const Context& dev_ctx,
                 const Scalar& min,
                 const Scalar& max,
                 DenseTensor* out) {
-  auto max_ = max.to<T>();
-  auto min_ = min.to<T>();
+  auto max_ = max.data.to<T>();
+  auto min_ = min.data.to<T>();

   PADDLE_ENFORCE_LE(
       min_,
diff --git a/paddle/phi/kernels/onednn/clip_kernel.cc b/paddle/phi/kernels/onednn/clip_kernel.cc
index 0accedb1724f29..2fda1351be6aa9 100644
--- a/paddle/phi/kernels/onednn/clip_kernel.cc
+++ b/paddle/phi/kernels/onednn/clip_kernel.cc
@@ -21,13 +21,13 @@ namespace phi {

 template <typename T, typename Context>
 void ClipKernel(const Context& dev_ctx,
                 const DenseTensor& x,
-                const Scalar& min,
-                const Scalar& max,
+                const DenseTensor& min,
+                const DenseTensor& max,
                 DenseTensor* out) {
   const auto& onednn_engine = dev_ctx.GetEngine();

   funcs::ClipOneDNNHandler<T> handler(
-      min, max, onednn_engine, dev_ctx.GetPlace(), &x);
+      &min, &max, onednn_engine, dev_ctx.GetPlace(), &x);

   auto src_memory_p = handler.AcquireSrcMemory(&x);
   auto dst_memory_p = handler.AcquireDstMemory(out);
diff --git a/paddle/phi/kernels/xpu/clip_kernel.cc b/paddle/phi/kernels/xpu/clip_kernel.cc
index 827882c1eb84b4..ed07ec0ab9b11b 100644
--- a/paddle/phi/kernels/xpu/clip_kernel.cc
+++ b/paddle/phi/kernels/xpu/clip_kernel.cc
@@ -25,8 +25,8 @@ namespace phi {

 template <typename T, typename Context>
 void ClipKernel(const Context& dev_ctx,
                 const DenseTensor& x,
-                const Scalar& min,
-                const Scalar& max,
+                const DenseTensor& min,
+                const DenseTensor& max,
                 DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
   using XPUDataType = typename XPUTypeTrait<T>::Type;
@@ -36,8 +36,8 @@ void ClipKernel(const Context& dev_ctx,
                  x_data,
                  out_data,
                  x.numel(),
-                 static_cast<XPUDataType>(min.to<T>()),
-                 static_cast<XPUDataType>(max.to<T>()));
+                 static_cast<XPUDataType>(min.data.to<T>()),
+                 static_cast<XPUDataType>(max.data.to<T>()));

   PADDLE_ENFORCE_EQ(r,
                     XPU_SUCCESS,
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index d4c6db2f59528a..70c3047b6eac86 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -380,8 +380,8 @@
     func : cholesky_solve_grad

 - backward_op : clip_double_grad
-  forward : clip_grad (Tensor x, Tensor grad_out, Scalar min = 0., Scalar max = 0.) -> Tensor(grad_x)
-  args : (Tensor x, Tensor grad_x_grad, Scalar min = 0., Scalar max = 0.)
+  forward : clip_grad (Tensor x, Tensor grad_out, Tensor min, Tensor max) -> Tensor(grad_x)
+  args : (Tensor x, Tensor grad_x_grad, Tensor min, Tensor max)
   output : Tensor(grad_out_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -391,8 +391,8 @@
     data_type : x

 - backward_op : clip_grad
-  forward : clip (Tensor x, Scalar min, Scalar max) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, Scalar min = 0., Scalar max = 0.)
+  forward : clip (Tensor x, Tensor min, Tensor max) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, Tensor min, Tensor max)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml
index 28a084353391f4..ee9a4afa60583b 100755
--- a/paddle/phi/ops/yaml/op_compat.yaml
+++ b/paddle/phi/ops/yaml/op_compat.yaml
@@ -589,16 +589,9 @@
 - op : clip
   backward : clip_grad, clip_double_grad
   inputs :
-    x : X
+    {x : X, min : Min, max : Max}
   outputs :
     out : Out
-  scalar :
-    min :
-      data_type : float
-      tensor_name : Min
-    max :
-      data_type : float
-      tensor_name : Max
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 93ab29ae5928c5..83dddc7acd6747 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -950,7 +950,7 @@
     interfaces : paddle::dialect::InferSymbolicShapeInterface

 - op : clip
-  args : (Tensor x, Scalar(float) min, Scalar(float) max)
+  args : (Tensor x, Tensor min, Tensor max)
   output : Tensor(out)
   inplace : (x -> out)
   infer_meta :
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index cd55d1372a42d3..954f780be56bc2 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -3716,8 +3716,8 @@ def log10_(x: Tensor, name: str | None = None) -> Tensor:

 def clip(
     x: Tensor,
-    min: float | None = None,
-    max: float | None = None,
+    min: float | Tensor | None = None,
+    max: float | Tensor | None = None,
     name: str | None = None,
 ) -> Tensor:
     """
@@ -3771,35 +3771,44 @@ def clip(
     min_ = float(np.finfo(np.float32).min)
     max_ = float(np.finfo(np.float32).max)

+    min = min_ if min is None else min
+    max = max_ if max is None else max
+
+    if not isinstance(min, Variable):
+        min = paddle.full_like(x, min, x.dtype)
+    if not isinstance(max, Variable):
+        max = paddle.full_like(x, max, x.dtype)
+
+    if min.shape != x.shape[-len(min.shape) :] or min.shape[-1] == 0:
+        raise ValueError(
+            f"The min dimension should be equal to the inner dimension of the x, but the min dimension is {min.shape}"
+        )
+
+    if max.shape != x.shape[-len(max.shape) :] or max.shape[-1] == 0:
+        raise ValueError(
+            f"The max dimension should be equal to the inner dimension of the x, but the max dimension is {max.shape}"
+        )
+
+    broadcast_min = paddle.broadcast_to(min, x.shape) if min.shape != x.shape else min
+    broadcast_max = paddle.broadcast_to(max, x.shape) if max.shape != x.shape else max
+
     if in_dynamic_or_pir_mode():
-        if isinstance(min, Variable):
-            min = min.item(0)
-        if isinstance(max, Variable):
-            max = max.item(0)
-        min = min_ if min is None else min
-        max = max_ if max is None else max
-        return _C_ops.clip(x, min, max)
+        return _C_ops.clip(x, broadcast_min, broadcast_max)
     else:
-        if min is not None:
-            check_type(min, 'min', (float, int, Variable), 'clip')
-            if isinstance(min, Variable):
-                check_dtype(
-                    min.dtype,
-                    'min',
-                    ['float16', 'float32', 'float64', 'int32', 'uint16'],
-                    'clip',
-                    '(When the type of min in clip is Variable.)',
-                )
-        if max is not None:
-            check_type(max, 'max', (float, int, Variable), 'clip')
-            if isinstance(max, Variable):
-                check_dtype(
-                    max.dtype,
-                    'max',
-                    ['float16', 'float32', 'float64', 'int32', 'uint16'],
-                    'clip',
-                    '(When the type of max in clip is Variable.)',
-                )
+        check_variable_and_dtype(
+            broadcast_min,
+            'min',
+            ['float16', 'float32', 'float64', 'int32', 'uint16'],
+            'clip',
+            '(When the type of min in clip is Variable.)',
+        )
+
+        check_variable_and_dtype(
+            broadcast_max,
+            'max',
+            ['float16', 'float32', 'float64', 'int32', 'uint16'],
+            'clip',
+            '(When the type of max in clip is Variable.)',
+        )

         check_variable_and_dtype(
             x,
@@ -3809,27 +3818,18 @@ def clip(
         )

         inputs = {'X': x}
-        attrs = {'min': min_, 'max': max_}

-        if isinstance(min, Variable):
-            min.stop_gradient = True
-            inputs['Min'] = min
-        elif min is not None:
-            attrs['min'] = min
+        min.stop_gradient = True
+        inputs['Min'] = min

-        if isinstance(max, Variable):
-            max.stop_gradient = True
-            inputs['Max'] = max
-        elif max is not None:
-            attrs['max'] = max
+        max.stop_gradient = True
+        inputs['Max'] = max

         helper = LayerHelper('clip', **locals())
         output = helper.create_variable_for_type_inference(
             dtype=helper.input_dtype('x')
         )
-        helper.append_op(
-            type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs
-        )
+        helper.append_op(type='clip', inputs=inputs, outputs={'Out': [output]})

         return output

@@ -3837,8 +3837,8 @@ def clip(
 @inplace_apis_in_dygraph_only
 def clip_(
     x: Tensor,
-    min: float | None = None,
-    max: float | None = None,
+    min: float | Tensor | None = None,
+    max: float | Tensor | None = None,
     name: str | None = None,
 ) -> Tensor:
     """
@@ -3847,15 +3847,30 @@ def clip_(
     fmin = float(np.finfo(np.float32).min)
     fmax = float(np.finfo(np.float32).max)
-    if isinstance(min, Variable):
-        min = min.item(0)
-    if isinstance(max, Variable):
-        max = max.item(0)
+
     min = fmin if min is None else min
     max = fmax if max is None else max

+    if not isinstance(min, Variable):
+        min = paddle.full_like(x, min, x.dtype)
+    if not isinstance(max, Variable):
+        max = paddle.full_like(x, max, x.dtype)
+
+    if min.shape != x.shape[-len(min.shape) :] or min.shape[-1] == 0:
+        raise ValueError(
+            f"The min dimension should be equal to the inner dimension of the x, but the min dimension is {min.shape}"
+        )
+
+    if max.shape != x.shape[-len(max.shape) :] or max.shape[-1] == 0:
+        raise ValueError(
+            f"The max dimension should be equal to the inner dimension of the x, but the max dimension is {max.shape}"
+        )
+
+    broadcast_min = paddle.broadcast_to(min, x.shape) if min.shape != x.shape else min
+    broadcast_max = paddle.broadcast_to(max, x.shape) if max.shape != x.shape else max
+
     if in_dynamic_mode():
-        return _C_ops.clip_(x, min, max)
+        return _C_ops.clip_(x, broadcast_min, broadcast_max)


 def trace(
diff --git a/test/legacy_test/test_clip_tensor_op.py b/test/legacy_test/test_clip_tensor_op.py
new file mode 100644
index 00000000000000..4ccd693efd6fd2
--- /dev/null
+++ b/test/legacy_test/test_clip_tensor_op.py
@@ -0,0 +1,330 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest + +import numpy as np + +import paddle +from paddle import base + + +class TestClipTensor(unittest.TestCase): + def test_static_clip(self): + data_shape = [1, 2, 3, 4] + self.place = ( + base.CUDAPlace(0) + if base.core.is_compiled_with_cuda() + else base.CPUPlace() + ) + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + max_data = np.random.uniform(1.5, 2.5, data_shape[-3:]).astype( + 'float32' + ) + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='x', shape=data_shape, dtype='float32') + min = paddle.static.data( + name='min', shape=data_shape[-2:], dtype='float32' + ) + max = paddle.static.data( + name='max', shape=data_shape[-3:], dtype='float32' + ) + out = paddle.clip(x, min, max) + exe = base.Executor(self.place) + res = exe.run( + feed={ + "x": data, + 'min': min_data, + 'max': max_data, + }, + fetch_list=[out], + ) + res_np = np.clip(data, min_data, max_data) + np.testing.assert_allclose(res_np, res[0], rtol=1e-05) + paddle.disable_static() + + data_shape = [1, 2, 3, 4] + self.place = ( + base.CUDAPlace(0) + if base.core.is_compiled_with_cuda() + else base.CPUPlace() + ) + + data = np.random.uniform(-2.0, 7.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='x', shape=data_shape, dtype='float32') + min = paddle.static.data( + name='min', shape=data_shape[-2:], dtype='float32' + ) + max = 5.0 + out = paddle.clip(x, min, max) + exe = base.Executor(self.place) + res = exe.run( + feed={ + "x": data, + 'min': min_data, + 'max': 5.0, + }, + fetch_list=[out], + ) + res_np = np.clip(data, min_data, 5.0) + np.testing.assert_allclose(res_np, res[0], rtol=1e-05) + paddle.disable_static() + + data_shape = [1, 2, 3, 4] + self.place = ( + base.CUDAPlace(0) + if base.core.is_compiled_with_cuda() + else base.CPUPlace() + ) + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + max_data = float(np.finfo(np.float32).max) + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='x', shape=data_shape, dtype='float32') + min = paddle.static.data( + name='min', shape=data_shape[-2:], dtype='float32' + ) + out = paddle.clip(x, min) + exe = base.Executor(self.place) + res = exe.run( + feed={"x": data, 'min': min_data}, + fetch_list=[out], + ) + res_np = np.clip(data, min_data, max_data) + np.testing.assert_allclose(res_np, res[0], rtol=1e-05) + paddle.disable_static() + + data_shape = [1, 2, 3, 4] + self.place = ( + base.CUDAPlace(0) + if base.core.is_compiled_with_cuda() + else base.CPUPlace() + ) + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, []).astype('float32') + max_data = np.random.uniform(1.5, 2.5, []).astype('float32') + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='x', shape=data_shape, dtype='float32') + min = paddle.static.data(name='min', shape=[], dtype='float32') + max = paddle.static.data(name='max', shape=[], dtype='float32') + out = paddle.clip(x, min, max) + exe = base.Executor(self.place) + res = exe.run( + feed={"x": data, 'min': 
min_data, 'max': max_data}, + fetch_list=[out], + ) + res_np = np.clip(data, min_data, max_data) + np.testing.assert_allclose(res_np, res[0], rtol=1e-05) + paddle.disable_static() + + data_shape = [1, 2, 3, 4] + self.place = ( + base.CUDAPlace(0) + if base.core.is_compiled_with_cuda() + else base.CPUPlace() + ) + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, [1]).astype('float32') + max_data = np.random.uniform(1.5, 2.5, [1]).astype('float32') + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='x', shape=data_shape, dtype='float32') + min = paddle.static.data(name='min', shape=[1], dtype='float32') + max = paddle.static.data(name='max', shape=[1], dtype='float32') + out = paddle.clip(x, min, max) + exe = base.Executor(self.place) + res = exe.run( + feed={"x": data, 'min': min_data, 'max': max_data}, + fetch_list=[out], + ) + res_np = np.clip(data, min_data, max_data) + np.testing.assert_allclose(res_np, res[0], rtol=1e-05) + paddle.disable_static() + + if base.core.is_compiled_with_cuda(): + paddle.enable_static() + data_shape = [1, 9, 9, 4] + data = np.random.random(data_shape).astype('float16') + + with paddle.static.program_guard(paddle.static.Program()): + images = paddle.static.data( + name='image1', shape=data_shape, dtype='float16' + ) + min = paddle.static.data( + name='min1', shape=[1], dtype='float16' + ) + max = paddle.static.data( + name='max1', shape=[1], dtype='float16' + ) + out = paddle.clip(images, min, max) + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + res1 = exe.run( + feed={ + "image1": data, + "min1": np.array([0.2]).astype('float16'), + "max1": np.array([0.8]).astype('float16'), + }, + fetch_list=[out], + ) + paddle.disable_static() + + def test_dygraph_clip(self): + self.place = ( + base.CUDAPlace(0) + if base.core.is_compiled_with_cuda() + else base.CPUPlace() + ) + paddle.disable_static(self.place) + data_shape = [1, 2, 3, 4] + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + max_data = np.random.uniform(1.5, 2.5, data_shape[-3:]).astype( + 'float32' + ) + out_np = np.clip(data, min_data, max_data) + data = paddle.to_tensor(data) + min_data = paddle.to_tensor(min_data) + max_data = paddle.to_tensor(max_data) + out = paddle.clip(data, min_data, max_data) + np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05) + + data_shape = [1, 2, 3, 4] + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + max_data = np.random.uniform(1.5, 2.5, data_shape[-1:]).astype( + 'float32' + ) + out_np = np.clip(data, min_data, max_data) + data = paddle.to_tensor(data) + min_data = paddle.to_tensor(min_data) + max_data = paddle.to_tensor(max_data) + out = paddle.clip(data, min_data, max_data) + np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05) + + data_shape = [1, 2, 3, 4] + data = np.random.uniform(-2.0, 7.0, data_shape).astype('int32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'int32' + ) + max_data = 5 + out_np = np.clip(data, min_data, max_data) + data = paddle.to_tensor(data, dtype='int32') + min_data = paddle.to_tensor(min_data, dtype='int32') + max_data = paddle.to_tensor([max_data], dtype='int32') + out = paddle.clip(data, min_data, max_data) + np.testing.assert_allclose(out.numpy(), 
out_np, rtol=1e-05) + + data_shape = [1, 2, 3, 4] + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + max_data = float(np.finfo(np.float32).max) + out_np = np.clip(data, min_data, max_data) + data = paddle.to_tensor(data) + min_data = paddle.to_tensor(min_data) + out = paddle.clip(data, min_data) + np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05) + + data_shape = [1, 2, 3, 4] + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + max_data = 5 + out_np = np.clip(data, min_data, max_data) + data = paddle.to_tensor(data) + min_data = paddle.to_tensor(min_data) + out = paddle.clip(data, min_data, max_data) + np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05) + + if base.core.is_compiled_with_cuda(): + data_shape = [1, 2, 3, 4] + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float16') + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float16' + ) + max_data = np.random.uniform(1.5, 2.5, data_shape[-1:]).astype( + 'float16' + ) + out_np = np.clip(data, min_data, max_data) + data = paddle.to_tensor(data) + min_data = paddle.to_tensor(min_data) + max_data = paddle.to_tensor(max_data) + out = paddle.clip(data, min_data, max_data) + np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05) + + data_shape = [1, 2, 3, 4] + data = np.random.random(data_shape).astype('float16') + min_data = -10.0 + max_data = 10.0 + out_np = np.clip(data, min_data, max_data) + data = paddle.to_tensor(data) + min_data = paddle.to_tensor(min_data) + max_data = paddle.to_tensor(max_data) + out = paddle.clip(data, min_data, max_data) + np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05) + + paddle.enable_static() + + def test_shapeerror_clip(self): + data_shape = [1, 9, 9, 4] + data = np.random.random(data_shape).astype('float32') + data = paddle.to_tensor(data) + with self.assertRaises(ValueError): + paddle.clip(data, paddle.rand([2])) + + data_shape = [1, 9, 9, 4] + data = np.random.random(data_shape).astype('float32') + data = paddle.to_tensor(data) + with self.assertRaises(ValueError): + paddle.clip( + data, min=paddle.to_tensor([1, 2, 3, 4]), max=paddle.rand([0]) + ) + + def test_tensor_clip_(self): + data_shape = [1, 9, 9, 4] + data = np.random.uniform(-2.0, 2.0, data_shape).astype('float32') + data = paddle.to_tensor(data) + min_data = np.random.uniform(-2.5, -1.5, data_shape[-2:]).astype( + 'float32' + ) + min = paddle.to_tensor(min_data) + max_data = np.random.uniform(1.5, 2.5, data_shape[-3:]).astype( + 'float32' + ) + max = paddle.to_tensor(max_data) + data.clip_(min, max) + + +if __name__ == '__main__': + unittest.main()
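
A minimal usage sketch (not part of the diff): with this change, paddle.clip accepts either Python floats or Tensors for min/max, and tensor bounds are broadcast against x before the clip kernel runs. The shapes and values below are illustrative only.

import numpy as np
import paddle

# x is clipped element-wise against per-element bounds that broadcast over x.
x = paddle.to_tensor(np.random.uniform(-2.0, 2.0, [2, 3]).astype('float32'))
min_t = paddle.full([3], -1.0, dtype='float32')  # matches the trailing dim of x
max_t = paddle.full([3], 1.0, dtype='float32')

out = paddle.clip(x, min_t, max_t)  # Tensor-valued bounds enabled by this PR
ref = np.clip(x.numpy(), min_t.numpy(), max_t.numpy())
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-05)

# Plain float bounds keep working; they are expanded via paddle.full_like internally.
out2 = paddle.clip(x, -1.0, 1.0)
np.testing.assert_allclose(out2.numpy(), out.numpy(), rtol=1e-05)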