diff --git a/paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc b/paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc index 64dd11095de4bd..c408e1a95ec68a 100644 --- a/paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc @@ -27,18 +27,12 @@ void ClipTensorGradKernel(const Context& dev_ctx, const DenseTensor& max, const DenseTensor& out_grad, DenseTensor* x_grad) { - DenseTensor ex_min; - MetaTensor meta_min(&ex_min); - CastInferMeta(min, x.dtype(), &meta_min); - DenseTensor ex_max; - MetaTensor meta_max(&ex_max); - CastInferMeta(max, x.dtype(), &meta_max); - phi::CastKernel(dev_ctx, min, x.dtype(), &ex_min); - phi::CastKernel(dev_ctx, max, x.dtype(), &ex_max); + DenseTensor tem_min = phi::Cast(dev_ctx, min, x.dtype()); + DenseTensor tem_max = phi::Cast(dev_ctx, max, x.dtype()); const T* x_data = x.data(); - const T* min_data = ex_min.data(); - const T* max_data = ex_max.data(); + const T* min_data = tem_min.data(); + const T* max_data = tem_max.data(); auto numel = x.numel(); auto* dout = out_grad.data(); diff --git a/paddle/phi/kernels/cpu/clip_tensor_kernel.cc b/paddle/phi/kernels/cpu/clip_tensor_kernel.cc index 6b3b74fb24b40e..bb46ef891af9fe 100644 --- a/paddle/phi/kernels/cpu/clip_tensor_kernel.cc +++ b/paddle/phi/kernels/cpu/clip_tensor_kernel.cc @@ -27,18 +27,12 @@ void ClipTensorKernel(const Context& dev_ctx, const DenseTensor& min, const DenseTensor& max, DenseTensor* out) { - DenseTensor ex_min; - MetaTensor meta_min(&ex_min); - CastInferMeta(min, x.dtype(), &meta_min); - DenseTensor ex_max; - MetaTensor meta_max(&ex_max); - CastInferMeta(max, x.dtype(), &meta_max); - phi::CastKernel(dev_ctx, min, x.dtype(), &ex_min); - phi::CastKernel(dev_ctx, max, x.dtype(), &ex_max); + DenseTensor tem_min = phi::Cast(dev_ctx, min, x.dtype()); + DenseTensor tem_max = phi::Cast(dev_ctx, max, x.dtype()); const T* x_data = x.data(); - const T* min_data = ex_min.data(); - const T* max_data = ex_max.data(); + const T* min_data 
= tem_min.data(); + const T* max_data = tem_max.data(); auto x_numel = x.numel(); diff --git a/paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu index 743d46f819a97b..e8d06a20fae4e6 100644 --- a/paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu @@ -44,19 +44,13 @@ void ClipTensorGradKernel(const Context& dev_ctx, const DenseTensor& max, const DenseTensor& out_grad, DenseTensor* x_grad) { - DenseTensor ex_min; - MetaTensor meta_min(&ex_min); - CastInferMeta(min, x.dtype(), &meta_min); - DenseTensor ex_max; - MetaTensor meta_max(&ex_max); - CastInferMeta(max, x.dtype(), &meta_max); - phi::CastKernel(dev_ctx, min, x.dtype(), &ex_min); - phi::CastKernel(dev_ctx, max, x.dtype(), &ex_max); + DenseTensor tem_min = phi::Cast(dev_ctx, min, x.dtype()); + DenseTensor tem_max = phi::Cast(dev_ctx, max, x.dtype()); const T* x_data = x.data(); auto numel = x.numel(); - const T* min_data = ex_min.data(); - const T* max_data = ex_max.data(); + const T* min_data = tem_min.data(); + const T* max_data = tem_max.data(); const T* out_grad_data = out_grad.data(); T* x_grad_data = dev_ctx.template Alloc(x_grad); diff --git a/paddle/phi/kernels/gpu/clip_tensor_kernel.cu b/paddle/phi/kernels/gpu/clip_tensor_kernel.cu index b698d87dc32f03..f7e948fd65ec67 100644 --- a/paddle/phi/kernels/gpu/clip_tensor_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_tensor_kernel.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/clip_kernel.h" +#include "paddle/phi/kernels/clip_tensor_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" @@ -27,7 +27,9 @@ namespace phi { template struct ClipTensorFunctor { inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const { - return x < min_ ? min_ : x > max_ ? 
max_ : x; + T x_ = x < min_ ? min_ : x; + T x__ = x_ > max_ ? max_ : x_; + return x__; } }; @@ -37,16 +39,10 @@ void ClipTensorKernel(const Context& dev_ctx, const DenseTensor& min, const DenseTensor& max, DenseTensor* out) { - DenseTensor ex_min; - MetaTensor meta_min(&ex_min); - CastInferMeta(min, x.dtype(), &meta_min); - DenseTensor ex_max; - MetaTensor meta_max(&ex_max); - CastInferMeta(max, x.dtype(), &meta_max); - phi::CastKernel(dev_ctx, min, x.dtype(), &ex_min); - phi::CastKernel(dev_ctx, max, x.dtype(), &ex_max); + DenseTensor tem_min = phi::Cast(dev_ctx, min, x.dtype()); + DenseTensor tem_max = phi::Cast(dev_ctx, max, x.dtype()); - std::vector ins = {&x, &ex_min, &ex_max}; + std::vector ins = {&x, &tem_min, &tem_max}; std::vector outs = {out}; dev_ctx.template Alloc(out); diff --git a/paddle/phi/kernels/onednn/clip_tensor_grad_kernel.cc b/paddle/phi/kernels/onednn/clip_tensor_grad_kernel.cc new file mode 100644 index 00000000000000..cfe7931b470e20 --- /dev/null +++ b/paddle/phi/kernels/onednn/clip_tensor_grad_kernel.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/clip_tensor_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/elementwise_kernel.h" + +namespace phi { +template +void ClipTensorGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& min, + const DenseTensor& max, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + // x_grad = out_grad * (min < x) * (max > x), built from four oneDNN binary ops. + const auto& onednn_engine = dev_ctx.GetEngine(); + auto& astream = OneDNNContext::tls().get_stream(); + // Scratch tensors (meta taken from x): the two masks and their product. + DenseTensor t_min_mask; + MetaTensor meta_min_mask(&t_min_mask); + UnchangedInferMeta(x, &meta_min_mask); + DenseTensor t_max_mask; + MetaTensor meta_max_mask(&t_max_mask); + UnchangedInferMeta(x, &meta_max_mask); + DenseTensor t_zero_mask; + MetaTensor meta_zero_mask(&t_zero_mask); + UnchangedInferMeta(x, &meta_zero_mask); + + auto* tem_min_mask = &t_min_mask; + auto* tem_max_mask = &t_max_mask; + auto* tem_zero_mask = &t_zero_mask; + auto* non_const_x = &x; + auto* non_const_min = &min; + auto* non_const_max = &max; + auto* non_const_out_grad = &out_grad; + // Step 1: t_min_mask = (min < x); min is SRC_0 and x is SRC_1. + funcs::BinaryOneDNNHandler Lesshandler(dnnl::algorithm::binary_lt, + -1, + onednn_engine, + dev_ctx.GetPlace(), + non_const_min, + non_const_x, + tem_min_mask, + 1.0f, + 1.0f, + 1.0f, + true); + + auto src_memory_p_min1 = Lesshandler.AcquireSrcMemory(non_const_min); + auto src_memory_p_x1 = + Lesshandler.AcquireSecondSrcMemory(non_const_x); + std::shared_ptr dst_memory_p1 = Lesshandler.AcquireDstMemory(tem_min_mask); + auto activation_p1 = Lesshandler.AcquireForwardPrimitive(); + + std::unordered_map args1 = { + {DNNL_ARG_SRC_0, *src_memory_p_min1}, + {DNNL_ARG_SRC_1, *src_memory_p_x1}, + {DNNL_ARG_DST, *dst_memory_p1}}; + + if (Lesshandler.Has_SRC_0_Scale()) { + args1.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, + Lesshandler.Get_SRC_0_Scale_Memory()}); + } + + if
(Lesshandler.Has_SRC_1_Scale()) { + args1.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, + Lesshandler.Get_SRC_1_Scale_Memory()}); + } + + activation_p1->execute(astream, args1); + // Step 2: t_max_mask = (max > x); max is SRC_0 and x is SRC_1. + funcs::BinaryOneDNNHandler Grahandler(dnnl::algorithm::binary_gt, + -1, + onednn_engine, + dev_ctx.GetPlace(), + non_const_max, + non_const_x, + tem_max_mask, + 1.0f, + 1.0f, + 1.0f, + true); + + auto src_memory_p_max2 = Grahandler.AcquireSrcMemory(non_const_max); + auto src_memory_p_x2 = + Grahandler.AcquireSecondSrcMemory(non_const_x); + std::shared_ptr dst_memory_p2 = Grahandler.AcquireDstMemory(tem_max_mask); + auto activation_p2 = Grahandler.AcquireForwardPrimitive(); + + std::unordered_map args2 = { + {DNNL_ARG_SRC_0, *src_memory_p_max2}, + {DNNL_ARG_SRC_1, *src_memory_p_x2}, + {DNNL_ARG_DST, *dst_memory_p2}}; + + if (Grahandler.Has_SRC_0_Scale()) { + args2.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, + Grahandler.Get_SRC_0_Scale_Memory()}); + } + + if (Grahandler.Has_SRC_1_Scale()) { + args2.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, + Grahandler.Get_SRC_1_Scale_Memory()}); + } + + activation_p2->execute(astream, args2); + // Step 3: t_zero_mask = t_min_mask * t_max_mask. + funcs::BinaryOneDNNHandler Mulhandler1(dnnl::algorithm::binary_mul, + -1, + onednn_engine, + dev_ctx.GetPlace(), + tem_min_mask, + tem_max_mask, + tem_zero_mask, + 1.0f, + 1.0f, + 1.0f, + true); + + auto src_memory_p_min3 = Mulhandler1.AcquireSrcMemory(tem_min_mask); + auto src_memory_p_max3 = Mulhandler1.AcquireSecondSrcMemory(tem_max_mask); + std::shared_ptr dst_memory_p3 = Mulhandler1.AcquireDstMemory(tem_zero_mask); + auto activation_p3 = Mulhandler1.AcquireForwardPrimitive(); + + std::unordered_map args3 = { + {DNNL_ARG_SRC_0, *src_memory_p_min3}, + {DNNL_ARG_SRC_1, *src_memory_p_max3}, + {DNNL_ARG_DST, *dst_memory_p3}}; + + if (Mulhandler1.Has_SRC_0_Scale()) { + args3.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, + Mulhandler1.Get_SRC_0_Scale_Memory()}); + } + + if (Mulhandler1.Has_SRC_1_Scale()) { +
args3.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, + Mulhandler1.Get_SRC_1_Scale_Memory()}); + } + + activation_p3->execute(astream, args3); + // Step 4: x_grad = t_zero_mask * out_grad. + funcs::BinaryOneDNNHandler Mulhandler2(dnnl::algorithm::binary_mul, + -1, + onednn_engine, + dev_ctx.GetPlace(), + tem_zero_mask, + non_const_out_grad, + x_grad, + 1.0f, + 1.0f, + 1.0f, + true); + + auto src_memory_p_zero4 = Mulhandler2.AcquireSrcMemory(tem_zero_mask); + auto src_memory_p_dout4 = Mulhandler2.AcquireSecondSrcMemory(non_const_out_grad); + std::shared_ptr dst_memory_p4 = Mulhandler2.AcquireDstMemory(x_grad); + auto activation_p4 = Mulhandler2.AcquireForwardPrimitive(); + + std::unordered_map args4 = { + {DNNL_ARG_SRC_0, *src_memory_p_zero4}, + {DNNL_ARG_SRC_1, *src_memory_p_dout4}, + {DNNL_ARG_DST, *dst_memory_p4}}; + + if (Mulhandler2.Has_SRC_0_Scale()) { + args4.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, + Mulhandler2.Get_SRC_0_Scale_Memory()}); + } + + if (Mulhandler2.Has_SRC_1_Scale()) { + args4.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, + Mulhandler2.Get_SRC_1_Scale_Memory()}); + } + + activation_p4->execute(astream, args4); + + astream.wait(); + + x_grad->set_mem_desc(dst_memory_p4->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL(clip_tensor_grad, + OneDNN, + ONEDNN, + phi::ClipTensorGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/clip_tensor_kernel.cc b/paddle/phi/kernels/onednn/clip_tensor_kernel.cc new file mode 100644 index 00000000000000..01efa003bb8151 --- /dev/null +++ b/paddle/phi/kernels/onednn/clip_tensor_kernel.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_tensor_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/elementwise_kernel.h" + +namespace phi { +template +void ClipTensorKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& min, + const DenseTensor& max, + DenseTensor* out) { + // Clips x as min(max(x, min), max) using two chained oneDNN binary primitives. + const auto& onednn_engine = dev_ctx.GetEngine(); + auto& astream = OneDNNContext::tls().get_stream(); + // t_out holds the intermediate max(x, min); its meta is taken from x. + DenseTensor t_out; + MetaTensor meta_out(&t_out); + UnchangedInferMeta(x, &meta_out); + auto* tem_out = &t_out; + auto* non_const_x = &x; + auto* non_const_min = &min; + auto* non_const_max = &max; + // Stage 1: t_out = binary_max(x, min). + funcs::BinaryOneDNNHandler MAXhandler(dnnl::algorithm::binary_max, + -1, + onednn_engine, + dev_ctx.GetPlace(), + non_const_x, + non_const_min, + tem_out, + 1.0f, + 1.0f, + 1.0f, + true); + + auto src_memory_p_x = MAXhandler.AcquireSrcMemory(non_const_x); + auto src_memory_p_min = MAXhandler.AcquireSecondSrcMemory(non_const_min); + std::shared_ptr dst_memory_p = MAXhandler.AcquireDstMemory(tem_out); + auto activation_p = MAXhandler.AcquireForwardPrimitive(); + + std::unordered_map args = { + {DNNL_ARG_SRC_0, *src_memory_p_x}, + {DNNL_ARG_SRC_1, *src_memory_p_min}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (MAXhandler.Has_SRC_0_Scale()) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, + MAXhandler.Get_SRC_0_Scale_Memory()}); + } + + if (MAXhandler.Has_SRC_1_Scale()) { +
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, + MAXhandler.Get_SRC_1_Scale_Memory()}); + } + + activation_p->execute(astream, args); + // Stage 2: out = binary_min(t_out, max). + funcs::BinaryOneDNNHandler MINhandler(dnnl::algorithm::binary_min, + -1, + onednn_engine, + dev_ctx.GetPlace(), + tem_out, + non_const_max, + out, + 1.0f, + 1.0f, + 1.0f, + true); + + auto src_memory_p_x2 = MINhandler.AcquireSrcMemory(tem_out); + auto src_memory_p_max2 = MINhandler.AcquireSecondSrcMemory(non_const_max); + std::shared_ptr dst_memory_p2 = MINhandler.AcquireDstMemory(out); + auto activation_p2 = MINhandler.AcquireForwardPrimitive(); + + std::unordered_map args2 = { + {DNNL_ARG_SRC_0, *src_memory_p_x2}, + {DNNL_ARG_SRC_1, *src_memory_p_max2}, + {DNNL_ARG_DST, *dst_memory_p2}}; + + if (MINhandler.Has_SRC_0_Scale()) { + args2.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, + MINhandler.Get_SRC_0_Scale_Memory()}); + } + + if (MINhandler.Has_SRC_1_Scale()) { + args2.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, + MINhandler.Get_SRC_1_Scale_Memory()}); + } + + activation_p2->execute(astream, args2); + // Both primitives were queued on the same stream; wait before publishing out. + astream.wait(); + + out->set_mem_desc(dst_memory_p2->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL(clip_tensor, + OneDNN, + ONEDNN, + phi::ClipTensorKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/clip_tensor_grad_kernel.cc b/paddle/phi/kernels/xpu/clip_tensor_grad_kernel.cc new file mode 100644 index 00000000000000..87277f658aab9e --- /dev/null +++ b/paddle/phi/kernels/xpu/clip_tensor_grad_kernel.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_tensor_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/compare_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/logical_kernel.h" +#include "paddle/phi/kernels/where_kernel.h" + +namespace phi { +// Passes out_grad through to x_grad where min < x < max and writes 0 elsewhere. +template +void ClipTensorGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& min, + const DenseTensor& max, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + DenseTensor ex_min = phi::Cast(dev_ctx, min, x.dtype()); + DenseTensor ex_max = phi::Cast(dev_ctx, max, x.dtype()); + // BOOL mask: ex_min < x. + phi::DenseTensor x_ls_min; + MetaTensor meta_x_ls_min(&x_ls_min); + UnchangedExceptDtypeInferMeta(x, &meta_x_ls_min); + meta_x_ls_min.set_dtype(phi::DataType::BOOL); + phi::LessThanKernel(dev_ctx, ex_min, x, &x_ls_min); + // BOOL mask: x < ex_max. + phi::DenseTensor x_ls_max; + MetaTensor meta_x_ls_max(&x_ls_max); + UnchangedExceptDtypeInferMeta(x, &meta_x_ls_max); + meta_x_ls_max.set_dtype(phi::DataType::BOOL); + phi::LessThanKernel(dev_ctx, x, ex_max, &x_ls_max); + // BOOL mask: logical AND of the two masks above. + phi::DenseTensor out; + MetaTensor meta_out(&out); + UnchangedExceptDtypeInferMeta(x, &meta_out); + meta_out.set_dtype(phi::DataType::BOOL); + phi::LogicalAndKernel(dev_ctx, x_ls_min, x_ls_max, &out); + // Zero tensor with x_grad's dims and dtype; WhereKernel picks it where the mask is false. + phi::DenseTensor zero_tensor; + MetaTensor meta_zero(&zero_tensor); + UnchangedInferMeta(x_grad, &meta_zero); + phi::FullKernel(dev_ctx, +
common::vectorize(x_grad->dims()), + 0.0f, + zero_tensor.dtype(), + &zero_tensor); + phi::WhereKernel(dev_ctx, out, out_grad, zero_tensor, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(clip_tensor_grad, + XPU, + ALL_LAYOUT, + phi::ClipTensorGradKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16, + int64_t, + int) {} diff --git a/paddle/phi/kernels/xpu/clip_tensor_kernel.cc b/paddle/phi/kernels/xpu/clip_tensor_kernel.cc new file mode 100644 index 00000000000000..968bff87258973 --- /dev/null +++ b/paddle/phi/kernels/xpu/clip_tensor_kernel.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "paddle/phi/kernels/clip_tensor_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/backends/xpu/xpu_header.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/elementwise_kernel.h" + +namespace phi { + +template +void ClipTensorKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& min, + const DenseTensor& max, + DenseTensor* out) { + DenseTensor tem_min = phi::Cast(dev_ctx, min, x.dtype()); + DenseTensor tem_max = phi::Cast(dev_ctx, max, x.dtype()); + // out = min(max(x, min), max), evaluated on the bounds cast to x's dtype. + DenseTensor tem_max_out = phi::Maximum(dev_ctx, tem_min, x); + MinimumKernel(dev_ctx, tem_max_out, tem_max, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(clip_tensor, + XPU, + ALL_LAYOUT, + phi::ClipTensorKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16, + int, + int64_t) {} diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index bcfff2a32a9796..d6983057db82b0 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3823,25 +3823,11 @@ def clip( if paddle.is_tensor(min) else paddle.full_like(x, float(min), x.dtype) ) - check_dtype( - min.dtype, - 'min', - ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], - 'clip_tensor', - '(When the type of min in clip is Variable.)', - ) max = ( max if paddle.is_tensor(max) else paddle.full_like(x, float(max), x.dtype) ) - check_dtype( - max.dtype, - 'max', - ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], - 'clip_tensor', - '(When the type of max in clip is Variable.)', - ) out_shape = get_clip_tensor_shape(x, min, max) x = paddle.broadcast_to(x, out_shape) if x.shape != out_shape else x min = ( @@ -3859,12 +3845,26 @@ def clip( if in_dynamic_or_pir_mode(): return _C_ops.clip_tensor(x, min, max) else: check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64',
'uint16'], 'clip', ) + check_dtype( + min.dtype, + 'min', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'clip_tensor', + '(When the type of min in clip is Variable.)', + ) + check_dtype( + max.dtype, + 'max', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'clip_tensor', + '(When the type of max in clip is Variable.)', + ) inputs = {'x': x, 'min': min, 'max': max} helper = LayerHelper('clip_tensor', **locals()) output = helper.create_variable_for_type_inference( @@ -3876,66 +3876,67 @@ def clip( outputs={'out': [output]}, ) return output + if in_dynamic_or_pir_mode(): + if isinstance(min, Variable): + min = min.item(0) + if isinstance(max, Variable): + max = max.item(0) + min = min_ if min is None else min + max = max_ if max is None else max + return _C_ops.clip(x, min, max) else: - if in_dynamic_or_pir_mode(): + if min is not None: + check_type(min, 'min', (float, int, Variable), 'clip') if isinstance(min, Variable): - min = min.item(0) + check_dtype( + min.dtype, + 'min', + ['float16', 'float32', 'float64', 'int32', 'uint16'], + 'clip', + '(When the type of min in clip is Variable.)', + ) + if max is not None: + check_type(max, 'max', (float, int, Variable), 'clip') if isinstance(max, Variable): - max = max.item(0) - return _C_ops.clip(x, min, max) - else: - if min is not None: - check_type(min, 'min', (float, int, Variable), 'clip') - if isinstance(min, Variable): - check_dtype( - min.dtype, - 'min', - ['float16', 'float32', 'float64', 'int32', 'uint16'], - 'clip', - '(When the type of min in clip is Variable.)', - ) - if max is not None: - check_type(max, 'max', (float, int, Variable), 'clip') - if isinstance(max, Variable): - check_dtype( - max.dtype, - 'max', - ['float16', 'float32', 'float64', 'int32', 'uint16'], - 'clip', - '(When the type of max in clip is Variable.)', - ) + check_dtype( + max.dtype, + 'max', + ['float16', 'float32', 'float64', 'int32', 'uint16'], + 'clip', + '(When the type of max in clip is 
Variable.)', + ) - check_variable_and_dtype( - x, - 'x', - ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], - 'clip', - ) + check_variable_and_dtype( + x, + 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'clip', + ) - inputs = {'X': x} - attrs = {'min': min_, 'max': max_} + inputs = {'X': x} + attrs = {'min': min_, 'max': max_} - if isinstance(min, Variable): - min.stop_gradient = True - inputs['Min'] = min - elif min is not None: - attrs['min'] = min + if paddle.is_tensor(min): + min.stop_gradient = True + inputs['Min'] = min + elif min is not None: + attrs['min'] = min - if isinstance(max, Variable): - max.stop_gradient = True - inputs['Max'] = max - elif max is not None: - attrs['max'] = max + if paddle.is_tensor(max): + max.stop_gradient = True + inputs['Max'] = max + elif max is not None: + attrs['max'] = max - helper = LayerHelper('clip', **locals()) - output = helper.create_variable_for_type_inference( - dtype=helper.input_dtype('x') - ) - helper.append_op( - type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs - ) + helper = LayerHelper('clip', **locals()) + output = helper.create_variable_for_type_inference( + dtype=helper.input_dtype('x') + ) + helper.append_op( + type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs + ) - return output + return output @inplace_apis_in_dygraph_only diff --git a/test/legacy_test/test_clip_op.py b/test/legacy_test/test_clip_op.py index 8086b565551b1c..4dd53b8036f19b 100644 --- a/test/legacy_test/test_clip_op.py +++ b/test/legacy_test/test_clip_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math import unittest import numpy as np @@ -488,6 +487,7 @@ class TestInplaceClipAPI(TestClipAPI): def _executed_api(self, x, min=None, max=None): return x.clip_(min, max) + class TestClipTensorAPI(unittest.TestCase): def initCase(self): self.x_shape = [10, 10, 1] @@ -529,7 +529,7 @@ def check_dygraph_api(self): out_pd = paddle.clip(x_pd, min, max) np.testing.assert_allclose(self.out_np, out_pd.numpy()) paddle.enable_static() - + def check_static_api(self): if self.dtype == 'float16': return @@ -555,11 +555,13 @@ def check_static_api(self): max_pd = None out_pd = paddle.clip(x_pd, min_pd, max_pd) res = exe.run( - main_program, feed={'x': self.x, 'min': self.min, 'max': self.max}, fetch_list=[out_pd] - ) + main_program, + feed={'x': self.x, 'min': self.min, 'max': self.max}, + fetch_list=[out_pd], + ) np.testing.assert_allclose(self.out_np, res[0]) paddle.disable_static() - + def check_inplace_api(self): if self.dtype == 'float16': return @@ -567,11 +569,10 @@ def check_inplace_api(self): x_pd = paddle.rand(self.x_shape, dtype=self.dtype) min_pd = paddle.rand([self.x_shape[0]], dtype=self.dtype) max_pd = paddle.rand([self.x_shape[0]], dtype=self.dtype) - x_pd.clip_(min_pd, max_pd) out_np = x_pd.numpy().clip(min_pd.numpy(), max_pd.numpy()) + x_pd.clip_(min_pd, max_pd) np.testing.assert_allclose(out_np, x_pd.numpy()) paddle.enable_static() - def test_fp16_api(self): if base.core.is_compiled_with_cuda(): @@ -606,34 +607,9 @@ def test_fp16_api(self): }, fetch_list=[out_pd], ) - np.testing.assert_allclose(self.out_np, res[0]) paddle.disable_static() -class TestClipTensorCase1(TestClipTensorAPI): - def initCase(self): - self.x_shape = [10, 10, 1] - self.min_shape = [1] - self.max_shape = [1] - self.dtype = 'float32' - - -class TestClipTensorCase2(TestClipTensorAPI): - def initCase(self): - self.x_shape = [10, 10, 1] - self.min_shape = [1] - self.max_shape = [1] - self.dtype = 'float16' - - -class TestClipTensorCase3(TestClipTensorAPI): - def initCase(self): - 
self.x_shape = [10, 10, 1] - self.min_shape = [1] - self.max_shape = [1] - self.dtype = 'float64' - - class TestClipTensorCase4(TestClipTensorAPI): def initCase(self): self.x_shape = [10, 1, 10] @@ -677,7 +653,7 @@ def initCase(self): class TestClipTensorCase9(TestClipTensorAPI): def initCase(self): self.x_shape = [10, 1, 10] - self.min_shape =None + self.min_shape = None self.max_shape = [10] self.dtype = 'float16' @@ -770,5 +746,186 @@ def initCase(self): self.max_shape = [10, 1, 10] +class TestClipTensorOp(OpTest): + def setUp(self): + self.max_relative_error = 0.006 + self.op_type = "clip_tensor" + self.python_api = paddle.clip + + self.inputs = {} + self.initTestCase() + input = np.random.random(self.shape).astype(self.dtype) + min_v = np.full(self.shape, self.min_value).astype(self.dtype) + max_v = np.full(self.shape, self.max_value).astype(self.dtype) + + input[np.abs(input - min_v) < self.max_relative_error] = 0.5 + input[np.abs(input - max_v) < self.max_relative_error] = 0.5 + + self.inputs['min'] = min_v + self.inputs['max'] = max_v + self.inputs['x'] = input + self.outputs = {'out': np.clip(input, min_v, max_v)} + + def test_check_output(self): + paddle.enable_static() + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + paddle.enable_static() + self.check_grad(['x'], 'out', check_pir=True) + + def initTestCase(self): + self.dtype = np.float32 + self.shape = (8, 5, 6) + self.min_value = 0.8 + self.max_value = 0.3 + + +class TestClipTensorOpCase1(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float32 + self.shape = (5, 6, 8) + self.max_value = 0.7 + self.min_value = 0.0 + + +class TestClipTensorOpCase2(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float32 + self.shape = (8, 5, 6) + self.max_value = 1.0 + self.min_value = 0.0 + + +class TestClipTensorOpCase3(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float32 + self.shape = (4, 8, 6) + self.max_value = 0.7 + self.min_value = 0.2 + + 
+class TestClipTensorOpFP16Case1(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float16 + self.shape = (5, 6, 8) + self.max_value = 0.7 + self.min_value = 0.0 + + +class TestClipTensorOpFP16Case2(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float16 + self.shape = (8, 5, 6) + self.max_value = 1.0 + self.min_value = 0.0 + + +class TestClipTensorOpFP16Case3(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float16 + self.shape = (5, 8, 6) + self.max_value = 0.7 + self.min_value = 0.2 + + +class TestClipTensorOpFP64Case1(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float64 + self.shape = (8, 6, 5) + self.max_value = 0.7 + self.min_value = 0.0 + + +class TestClipTensorOpFP64Case2(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float64 + self.shape = (8, 5, 6) + self.max_value = 1.0 + self.min_value = 0.0 + + +class TestClipTensorOpFP64Case3(TestClipTensorOp): + def initTestCase(self): + self.dtype = np.float64 + self.shape = (4, 8, 6) + self.max_value = 0.7 + self.min_value = 0.2 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the bfloat16", +) +class TestClipTensorBF16Op(OpTest): + def setUp(self): + self.max_relative_error = 0.006 + self.op_type = "clip_tensor" + self.python_api = paddle.clip + self.inputs = {} + self.initTestCase() + + self.inputs['x'] = np.random.random(self.shape).astype(np.float32) + self.inputs['min'] = np.full(self.shape, self.min_value).astype( + np.float32 + ) + self.inputs['max'] = np.full(self.shape, self.max_value).astype( + np.float32 + ) + min_v = self.inputs['min'] + max_v = self.inputs['max'] + + self.inputs['x'][ + np.abs(self.inputs['x'] - min_v) < self.max_relative_error + ] = 0.5 + self.inputs['x'][ + np.abs(self.inputs['x'] - max_v) < self.max_relative_error + ] = 0.5 + # Build the float32 reference before the inputs are re-encoded as uint16 bf16 bits. + out = np.clip(self.inputs['x'], min_v, max_v) + + self.inputs['x'] = convert_float_to_uint16(self.inputs['x']) +
self.inputs['min'] = convert_float_to_uint16(self.inputs['min']) + self.inputs['max'] = convert_float_to_uint16(self.inputs['max']) + self.outputs = {'out': convert_float_to_uint16(out)} + + def test_check_output(self): + place = paddle.CUDAPlace(0) + paddle.enable_static() + self.check_output_with_place(place) + + def test_check_grad_normal(self): + place = paddle.CUDAPlace(0) + paddle.enable_static() + self.check_grad_with_place(place, ['x'], 'out') + + def initTestCase(self): + self.shape = (8, 5, 6) + self.min_value = 0.8 + self.max_value = 0.3 + + +class TestClipTensorOpBF16Case1(TestClipTensorBF16Op): + def initTestCase(self): + self.shape = (8, 6, 5) + self.max_value = 0.7 + self.min_value = 0.0 + + +class TestClipTensorOpBF16Case2(TestClipTensorBF16Op): + def initTestCase(self): + self.shape = (5, 8, 6) + self.max_value = 1.0 + self.min_value = 0.0 + + +class TestClipTensorOpBF16Case3(TestClipTensorBF16Op): + def initTestCase(self): + self.shape = (4, 8, 7) + self.max_value = 0.7 + self.min_value = 0.2 + + if __name__ == '__main__': unittest.main() diff --git a/test/mkldnn/test_clip_tensor_mkldnn_op.py b/test/mkldnn/test_clip_tensor_mkldnn_op.py new file mode 100644 index 00000000000000..da14898158a975 --- /dev/null +++ b/test/mkldnn/test_clip_tensor_mkldnn_op.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import unittest + +import numpy as np +from op_test import OpTest, convert_float_to_uint16 + +import paddle +from paddle.base import core + + +@unittest.skipIf( + not core.supports_bfloat16(), "place does not support BF16 evaluation" +) +class TestClipTensorBf16(OpTest): + def setUp(self): + self.max_relative_error = 0.006 + self.op_type = "clip_tensor" + self.init_dtype() + self.initTestCase() + self.inputs = {} + self.inputs['x'] = np.random.random(self.shape).astype(self.dtype) + self.inputs['min'] = np.full(self.shape, self.min_value).astype( + self.dtype + ) + self.inputs['max'] = np.full(self.shape, self.max_value).astype( + self.dtype + ) + min_v = self.inputs['min'] + max_v = self.inputs['max'] + + self.inputs['x'][ + np.abs(self.inputs['x'] - min_v) < self.max_relative_error + ] = 0.5 + self.inputs['x'][ + np.abs(self.inputs['x'] - max_v) < self.max_relative_error + ] = 0.5 + self.out = np.clip(self.inputs['x'], min_v, max_v) + + self.x_bf16 = convert_float_to_uint16(self.inputs['x']) + self.min_bf16 = convert_float_to_uint16(self.inputs['min']) + self.max_bf16 = convert_float_to_uint16(self.inputs['max']) + self.out_bf16 = convert_float_to_uint16(self.out) + self.inputs = { + 'x': self.x_bf16, + 'min': self.min_bf16, + 'max': self.max_bf16, + } + self.attrs = { + 'use_mkldnn': True, + 'mkldnn_data_type': self.mkldnn_data_type, + } + self.outputs = {'out': self.out_bf16} + + def init_dtype(self): + self.dtype = np.float32 + self.mkldnn_data_type = "bfloat16" + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace(), check_pir_onednn=True) + + def test_check_grad(self): + self.check_grad_with_place( + core.CPUPlace(), ["x"], "out", check_pir_onednn=True + ) + + def initTestCase(self): + self.shape = (10, 1, 10) + self.min_value = 0.8 + self.max_value = 0.3 + + +class TestBf16Case1(TestClipTensorBf16): + def initTestCase(self): + self.shape = (8, 6, 8) + self.max_value = 0.7 + self.min_value = 0.0 + + +class 
TestBf16Case2(TestClipTensorBf16): + def initTestCase(self): + self.shape = (8, 8, 6) + self.max_value = 1.0 + self.min_value = 0.0 + + +class TestBf16Case3(TestClipTensorBf16): + def initTestCase(self): + self.shape = (4, 8, 6) + self.max_value = 0.7 + self.min_value = 0.2 + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/test/white_list/op_accuracy_white_list.py b/test/white_list/op_accuracy_white_list.py index 5f6e8ee790fc28..da8fcf6767ddd7 100644 --- a/test/white_list/op_accuracy_white_list.py +++ b/test/white_list/op_accuracy_white_list.py @@ -17,6 +17,7 @@ 'instance_norm', 'affine_grid', 'clip', + 'clip_tensor', 'conv2d', 'conv2d_transpose', 'conv3d', diff --git a/test/xpu/test_clip_op_xpu.py b/test/xpu/test_clip_op_xpu.py index 2c9229f2afbec4..67dc8bddb11b9d 100644 --- a/test/xpu/test_clip_op_xpu.py +++ b/test/xpu/test_clip_op_xpu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -271,5 +271,296 @@ def _executed_api(self, x, min=None, max=None): continue create_test_class(globals(), XPUTestClipOp, stype) + +class TestClipTensorAPI(unittest.TestCase): + def initCase(self): + self.x_shape = [10, 10, 1] + self.min_shape = [10] + self.max_shape = [10] + self.dtype = 'float32' + + def setUp(self): + self.initCase() + self.place = ( + base.XPUPlace(0) + if base.core.is_compiled_with_xpu() + else base.CPUPlace() + ) + self.x = np.random.random(self.x_shape).astype(self.dtype) + if self.min_shape is None: + self.min = None + else: + self.min = np.random.random(self.min_shape).astype(self.dtype) + if self.max_shape is None: + self.max = None + else: + self.max = np.random.random(self.max_shape).astype(self.dtype) + self.out_np = self.x.clip(self.min, self.max) + + def check_dygraph_api(self): + paddle.disable_static(self.place) + x_pd = paddle.to_tensor(self.x) + if self.min is None: + min = None + else: + min = paddle.to_tensor(self.min) + if self.max is None: + max = None + else: + max = paddle.to_tensor(self.max) + out_pd = paddle.clip(x_pd, min, max) + np.testing.assert_allclose(self.out_np, out_pd.numpy()) + paddle.enable_static() + + def check_static_api(self): + paddle.enable_static() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + exe = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + x_pd = paddle.static.data( + name='x', shape=self.x_shape, dtype=self.dtype + ) + if self.min is not None: + min_pd = paddle.static.data( + name='min', shape=self.min_shape, dtype=self.dtype + ) + else: + min_pd = None + if self.max is not None: + max_pd = paddle.static.data( + name='max', shape=self.max_shape, dtype=self.dtype + ) + else: + max_pd = None + out_pd = paddle.clip(x_pd, min_pd, max_pd) + res = exe.run( + main_program, + feed={'x': self.x, 'min': self.min, 'max': self.max}, + fetch_list=[out_pd], + ) + np.testing.assert_allclose(self.out_np, res[0]) + 
paddle.disable_static() + + def check_inplace_api(self): + paddle.disable_static(self.place) + x_pd = paddle.rand(self.x_shape, dtype=self.dtype) + min_pd = paddle.rand([self.x_shape[0]], dtype=self.dtype) + max_pd = paddle.rand([self.x_shape[0]], dtype=self.dtype) + out_np = x_pd.numpy().clip(min_pd.numpy(), max_pd.numpy()) + x_pd.clip_(min_pd, max_pd) + np.testing.assert_allclose(out_np, x_pd.numpy()) + paddle.enable_static() + + +class TestClipTensorCase1(TestClipTensorAPI): + def initCase(self): + self.x_shape = [10, 1, 10] + self.min_shape = [10] + self.max_shape = [10] + self.dtype = 'float32' + + + +class TestClipTensorCase2(TestClipTensorAPI): + def initCase(self): + self.x_shape = [10, 1, 10] + self.min_shape = None + self.max_shape = [10] + self.dtype = 'float32' + + +class TestClipTensorCase3(TestClipTensorAPI): + def initCase(self): + self.x_shape = [10, 1, 10] + self.min_shape = [10] + self.max_shape = None + self.dtype = 'float32' + + +class TestClipTensorCase4(TestClipTensorAPI): + def initCase(self): + self.dtype = 'int32' + self.x_shape = [10, 1, 10] + self.min_shape = [10] + self.max_shape = [10] + + +class TestClipTensorCase5(TestClipTensorAPI): + def initCase(self): + self.dtype = 'int64' + self.x_shape = [10, 1, 10] + self.min_shape = [10] + self.max_shape = [10] + + +class TestClipTensorCase6(TestClipTensorAPI): + def initCase(self): + self.dtype = 'int32' + self.x_shape = [10, 1, 10] + self.min_shape = None + self.max_shape = [10] + + +class TestClipTensorCase7(TestClipTensorAPI): + def initCase(self): + self.dtype = 'int64' + self.x_shape = [10, 1, 10] + self.min_shape = None + self.max_shape = [10] + + +class TestClipTensorCase8(TestClipTensorAPI): + def initCase(self): + self.dtype = 'int32' + self.x_shape = [10, 1, 10] + self.min_shape = [10] + self.max_shape = None + + +class TestClipTensorCase9(TestClipTensorAPI): + def initCase(self): + self.dtype = 'int64' + self.x_shape = [10, 1, 10] + self.min_shape = [10] + self.max_shape = None + 
+ +class TestClipTensorCase10(TestClipTensorAPI): + def initCase(self): + self.dtype = 'float32' + self.x_shape = [10] + self.min_shape = [10, 1, 10] + self.max_shape = [10] + + +class TestClipTensorCase11(TestClipTensorAPI): + def initCase(self): + self.dtype = 'float32' + self.x_shape = [10] + self.min_shape = [10] + self.max_shape = [10, 1, 10] + + +class XPUTestClipTensorOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'clip_tensor' + self.use_dynamic_create_class = False + + class ClipTensorOp(XPUOpTest): + def setUp(self): + self.python_api = paddle.clip + self.inputs = {} + self.init_dtype() + self.set_xpu() + self.op_type = "clip_tensor" + self.place = paddle.XPUPlace(0) + self.init_data() + self.set_inputs() + if self.dtype == np.uint16: + self.outputs = { + 'out': convert_float_to_uint16( + np.clip( + convert_uint16_to_float(self.inputs['x']), + convert_uint16_to_float(self.inputs['min']), + convert_uint16_to_float(self.inputs['max']), + ) + ) + } + else: + self.outputs = { + 'out': np.clip( + self.inputs['x'], + self.inputs['min'], + self.inputs['max'], + ) + } + + def set_xpu(self): + self.__class__.use_xpu = True + self.__class__.no_need_check_grad = False + self.__class__.op_type = self.dtype + + def init_data(self): + self.shape = (10, 1, 10) + self.min_value = 0.8 + self.max_value = 0.3 + + def set_inputs(self): + self.inputs['x'] = np.random.random(self.shape).astype("float32") + self.inputs['min'] = np.full(self.shape, self.min_value).astype( + 'float32' + ) + self.inputs['max'] = np.full(self.shape, self.max_value).astype( + 'float32' + ) + + self.min_v = self.inputs['min'] + self.max_v = self.inputs['max'] + + self.max_relative_error = 0.006 + self.inputs['x'][ + np.abs(self.inputs['x'] - self.min_v) < self.max_relative_error + ] = 0.5 + self.inputs['x'][ + np.abs(self.inputs['x'] - self.max_v) < self.max_relative_error + ] = 0.5 + if self.dtype == np.uint16: + self.inputs['x'] = convert_float_to_uint16(self.inputs['x']) + 
self.inputs['min'] = convert_float_to_uint16(self.inputs['min']) + self.inputs['max'] = convert_float_to_uint16(self.inputs['max']) + else: + self.inputs['x'] = self.inputs['x'].astype(self.dtype) + self.inputs['min'] = self.inputs['min'].astype(self.dtype) + self.inputs['max'] = self.inputs['max'].astype(self.dtype) + + def init_dtype(self): + self.dtype = self.in_type + + def test_check_output(self): + paddle.enable_static() + self.check_output_with_place(self.place) + paddle.disable_static() + + def test_check_grad(self): + if hasattr(self, "no_need_check_grad") and self.no_need_check_grad: + return + if core.is_compiled_with_xpu(): + paddle.enable_static() + self.check_grad_with_place(self.place, ['x'], 'out') + paddle.disable_static() + + class TestClipTensorOp1(ClipTensorOp): + def init_data(self): + self.shape = (8, 6, 8) + self.max_value = 0.7 + self.min_value = 0.0 + + class TestClipTensorOp2(ClipTensorOp): + def init_data(self): + self.shape = (8, 8, 6) + self.max_value = 1.0 + self.min_value = 0.0 + + class TestClipTensorOp3(ClipTensorOp): + def init_data(self): + self.shape = (4, 8, 6) + self.max_value = 0.7 + self.min_value = 0.2 + + class TestClipTensorOp4(ClipTensorOp): + def init_data(self): + self.shape = (4, 8, 6) + self.max_value = 0.5 + self.min_value = 0.5 + + +support_types = get_xpu_op_support_types('clip_tensor') +for stype in support_types: + # TODO: disable int32 and int64 test temporarily, as xdnn does not support the corresponding reduce_mean + if stype in ["int32", "int64"]: + continue + create_test_class(globals(), XPUTestClipTensorOp, stype) + if __name__ == '__main__': unittest.main()