diff --git a/paddle/phi/kernels/clip_grad_kernel.h b/paddle/phi/kernels/clip_grad_kernel.h
index 2756cf70cc02eb..450e782438e2e6 100644
--- a/paddle/phi/kernels/clip_grad_kernel.h
+++ b/paddle/phi/kernels/clip_grad_kernel.h
@@ -31,8 +31,8 @@ void ClipGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ClipWithTensorGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
-                              const DenseTensor& out_grad,
                               const DenseTensor& min,
                               const DenseTensor& max,
+                              const DenseTensor& out_grad,
                               DenseTensor* x_grad);
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/clip_grad_kernel.cc b/paddle/phi/kernels/cpu/clip_grad_kernel.cc
index 78d1eb17e77964..f4a3bf6e69a100 100644
--- a/paddle/phi/kernels/cpu/clip_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/clip_grad_kernel.cc
@@ -21,12 +21,11 @@ namespace phi {
 
 template <typename T, typename Context>
-void ClipWithTensorGradKernel(const Context& ctx,
-                              const Context& dev_ctx,
+void ClipWithTensorGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
-                              const DenseTensor& out_grad,
                               const DenseTensor& min,
                               const DenseTensor& max,
+                              const DenseTensor& out_grad,
                               DenseTensor* x_grad) {
   const T* x_data = x.data<T>();
   const T* min_data = min.data<T>();
@@ -34,7 +33,7 @@ void ClipWithTensorGradKernel(const Context& ctx,
   auto numel = x.numel();
   auto* dout = out_grad.data<T>();
 
-  auto* dx = ctx.template Alloc<T>(x_grad);
+  auto* dx = dev_ctx.template Alloc<T>(x_grad);
   for (int i = 0; i < numel; i++) {
     dx[i] = (x_data[i] > min_data[i] && x_data[i] < max_data[i]) ? dout[i] : static_cast<T>(0);
   }
@@ -58,4 +57,4 @@ PD_REGISTER_KERNEL(clipwithtensor_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
\ No newline at end of file
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/clip_kernel.cc b/paddle/phi/kernels/cpu/clip_kernel.cc
index 0bd2c72b6bd8bc..96f868166f180f 100644
--- a/paddle/phi/kernels/cpu/clip_kernel.cc
+++ b/paddle/phi/kernels/cpu/clip_kernel.cc
@@ -21,24 +21,20 @@ namespace phi {
 
 template <typename T, typename Context>
-void ClipWithTensorKernel(const Context& ctx,
+void ClipWithTensorKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& min,
                           const DenseTensor& max,
                           DenseTensor* out) {
-  const T* x_data = x.data();
+  const T* x_data = x.data<T>();
   const T* min_data = min.data<T>();
   const T* max_data = max.data<T>();
   auto x_numel = x.numel();
 
-  T* out_data = ctx.template Alloc<T>(out);
+  T* out_data = dev_ctx.template Alloc<T>(out);
 
   for (int i = 0; i < x_numel; i++) {
-    PADDLE_ENFORCE_LE(
-        min_data[i],
-        max_data[i],
-        errors::InvalidArgument("max should be greater than or equal to min. "));
-    out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x;
+    out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x_data[i];
   }
 }
 
@@ -48,4 +44,4 @@ PD_REGISTER_KERNEL(
     clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}
 
 PD_REGISTER_KERNEL(
-    clipwithtensor, CPU, ALL_LAYOUT, phi::ClipWithTensorKernel, float, double, int, int64_t) {}
\ No newline at end of file
+    clipwithtensor, CPU, ALL_LAYOUT, phi::ClipWithTensorKernel, float, double, int, int64_t) {}
diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
index 7ca78c631e3156..3cd032b196893a 100644
--- a/paddle/phi/kernels/gpu/clip_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
@@ -23,25 +23,33 @@ namespace phi {
 
 template <typename T>
-class ClipWithTensorGradFunctor {
-  HOSTDEVICE T operator()(const T x, const T y, const T min_, const max_) const {
-    return (y > min_ && y < max_) ? x : static_cast<T>(0);
+__global__ void ClipWithTensorGradFunctor(const int N, const T* out_grad, const T* x, const T* min, const T* max, T* x_grad) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < N; idx += blockDim.x * gridDim.x) {
+    x_grad[idx] = (x[idx] > min[idx]) && (x[idx] < max[idx]) ? out_grad[idx] : static_cast<T>(0);
   }
 };
 
 template <typename T, typename Context>
 void ClipWithTensorGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
-                              const DenseTensor& out_grad,
                               const DenseTensor& min,
                               const DenseTensor& max,
+                              const DenseTensor& out_grad,
                               DenseTensor* x_grad) {
-  std::vector<const DenseTensor*> ins = {&out_grad, &x, &min, &max};
-  std::vector<DenseTensor*> outs = {x_grad};
-  auto functor = ClipWithTensorGradFunctor<T>();
-  dev_ctx.template Alloc<T>(x_grad);
-  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+  const T* x_data = x.data<T>();
+  auto numel = x.numel();
+  const T* min_data = min.data<T>();
+  const T* max_data = max.data<T>();
+  const T* out_grad_data = out_grad.data<T>();
+
+  T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+
+  auto stream = dev_ctx.stream();
+  auto config = backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
+  ClipWithTensorGradFunctor<T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
+      numel, out_grad_data, x_data, min_data, max_data, x_grad_data);
 }
 }
@@ -65,4 +73,4 @@ PD_REGISTER_KERNEL(clipwithtensor_grad,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
\ No newline at end of file
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu
index b64cbd22d55a2e..690848c2759996 100644
--- a/paddle/phi/kernels/gpu/clip_kernel.cu
+++ b/paddle/phi/kernels/gpu/clip_kernel.cu
@@ -14,33 +14,35 @@
 
 #include "paddle/phi/kernels/clip_kernel.h"
 
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/clip_kernel_impl.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
+#include "paddle/phi/kernels/funcs/elementwise_functor.h"
 
 namespace phi {
 
 template <typename T>
 struct ClipWithTensorFunctor {
-  inline HOSTDEVICE T operator()(const bool x, const T min_, const T max_) const {
-    return x < min_ ? min_ : x > max_ ? max_ : x;
+  inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const {
+    return x < min_ ? min_ : (x > max_ ? max_ : x);
   }
 };
 
 template <typename T, typename Context>
-void ClipWithTensorKernel(const Context& ctx,
+void ClipWithTensorKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& min,
                           const DenseTensor& max,
                           DenseTensor* out) {
   std::vector<const DenseTensor*> ins = {&x, &min, &max};
   std::vector<DenseTensor*> outs = {out};
-  ctx.template Alloc<T>(out);
+  dev_ctx.template Alloc<T>(out);
   ClipWithTensorFunctor<T> func;
-  funcs::ElementwiseKernel<T, ClipWithTensorFunctor<T>, 1>(ctx, ins, &outs, func);
+  funcs::ElementwiseKernel<T, ClipWithTensorFunctor<T>, 1>(dev_ctx, ins, &outs, func);
 }
 
 }  // namespace phi
@@ -65,4 +67,4 @@ PD_REGISTER_KERNEL(clipwithtensor,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
\ No newline at end of file
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/xpu/clip_grad_kernel.cc b/paddle/phi/kernels/xpu/clip_grad_kernel.cc
index 5e1e7812e74895..0a8a523e0bd734 100644
--- a/paddle/phi/kernels/xpu/clip_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/clip_grad_kernel.cc
@@ -14,8 +14,13 @@
 
 #include "paddle/phi/kernels/clip_grad_kernel.h"
 
+#include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_header.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/compare_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/where_kernel.h"
 
 namespace phi {
 
@@ -38,6 +43,27 @@ void ClipGradKernel(const Context& ctx,
                      static_cast<XPUDataType>(max.to<T>()));
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
 }
+
+template <typename T, typename Context>
+void ClipWithTensorGradKernel(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& min,
+                              const DenseTensor& max,
+                              const DenseTensor& out_grad,
+                              DenseTensor* x_grad) {
+  dev_ctx.template Alloc<T>(x_grad);
+  using XPUDataType = typename XPUTypeTrait<T>::Type;
+
+  DenseTensor min_tensor(phi::DataType::BOOL);
+  DenseTensor max_tensor(phi::DataType::BOOL);
+  LessThanKernel<T, Context>(dev_ctx, min, x, &min_tensor);
+  LessThanKernel<T, Context>(dev_ctx, x, max, &max_tensor);
+  DenseTensor out(phi::DataType::BOOL);
+  EqualKernel<bool, Context>(dev_ctx, min_tensor, max_tensor, &out);
+  DenseTensor zero_tensor(x_grad->dtype());
+  FullKernel<T, Context>(dev_ctx, common::vectorize(x_grad->dims()), 0.0f, zero_tensor.dtype(), &zero_tensor);
+  WhereKernel<T, Context>(dev_ctx, out, out_grad, zero_tensor, x_grad);
+}
 }  // namespace phi
 
 PD_REGISTER_KERNEL(clip_grad,
@@ -48,3 +74,12 @@ PD_REGISTER_KERNEL(clip_grad,
                    float,
                    phi::dtype::float16,
                    int64_t,
                    int) {}
+
+PD_REGISTER_KERNEL(clipwithtensor_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ClipWithTensorGradKernel,
+                   float,
+                   phi::dtype::float16,
+                   int64_t,
+                   int) {}
diff --git a/paddle/phi/kernels/xpu/clip_kernel.cc b/paddle/phi/kernels/xpu/clip_kernel.cc
index 827882c1eb84b4..cd6d3c58dfb4ba 100644
--- a/paddle/phi/kernels/xpu/clip_kernel.cc
+++ b/paddle/phi/kernels/xpu/clip_kernel.cc
@@ -17,8 +17,11 @@
 
 #include "glog/logging.h"
 
 #include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_header.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/compare_kernel.h"
+#include "paddle/phi/kernels/where_kernel.h"
 
 namespace phi {
 
@@ -47,6 +50,56 @@ void ClipKernel(const Context& dev_ctx,
                                    XPUAPIErrorMsg[r]));
 }
 
+template <typename T, typename Context>
+void ClipWithTensorKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& min,
+                          const DenseTensor& max,
+                          DenseTensor* out) {
+  using XPUDataType = typename XPUTypeTrait<T>::Type;
+  const XPUDataType* x_data = reinterpret_cast<const XPUDataType*>(x.data<T>());
+  const XPUDataType* min_data = reinterpret_cast<const XPUDataType*>(min.data<T>());
+  const XPUDataType* max_data = reinterpret_cast<const XPUDataType*>(max.data<T>());
+  XPUDataType* out_data = reinterpret_cast<XPUDataType*>(dev_ctx.template Alloc<T>(out));
+
+  auto min_dims = common::vectorize(min.dims());
+  if (min_dims.size() == 0) {
+    min_dims = std::vector<int64_t>({1});
+  }
+  auto max_dims = common::vectorize(max.dims());
+  if (max_dims.size() == 0) {
+    max_dims = std::vector<int64_t>({1});
+  }
+
+  DenseTensor min_tensor(phi::DataType::BOOL);
+  LessThanKernel<T, Context>(dev_ctx, x, min, &min_tensor);
+
+  auto min_tensor_dims = common::vectorize(min_tensor.dims());
+  if (min_tensor_dims.size() == 0) {
+    min_tensor_dims = std::vector<int64_t>({1});
+  }
+
+  const bool* min_tensor_data = min_tensor.data<bool>();
+  int ret = xpu::select(
+      dev_ctx.x_context(), min_tensor_data, min_data, x_data, out_data, min_tensor_dims, min_dims);
+
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");
+
+  DenseTensor max_tensor(phi::DataType::BOOL);
+  LessThanKernel<T, Context>(dev_ctx, max, x, &max_tensor);
+
+  auto max_tensor_dims = common::vectorize(max_tensor.dims());
+  if (max_tensor_dims.size() == 0) {
+    max_tensor_dims = std::vector<int64_t>({1});
+  }
+
+  const bool* max_tensor_data = max_tensor.data<bool>();
+  int ret2 = xpu::select(
+      dev_ctx.x_context(), max_tensor_data, max_data, x_data, out_data, max_tensor_dims, max_dims);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret2, "xpu::select");
+
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(clip,
@@ -58,3 +111,13 @@ PD_REGISTER_KERNEL(clip,
                    phi::dtype::bfloat16,
                    int64_t,
                    int) {}
+
+PD_REGISTER_KERNEL(clipwithtensor,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ClipWithTensorKernel,
+                   float,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   int64_t,
+                   int) {}
\ No newline at end of file
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index 093cf79f703bda..aaf655913a8995 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -403,8 +403,8 @@
   inplace : (out_grad -> x_grad)
 
 - backward_op : clipwithtensor_double_grad
-  forward : clipwithtensor_grad (Tensor x, Tensor grad_out, Tensor min, Tensor max) -> Tensor(grad_x)
-  args : (Tensor x, Tensor grad_x_grad, Tensor min, Tensor max)
+  forward : clipwithtensor_grad (Tensor x, Tensor min, Tensor max, Tensor grad_out) -> Tensor(grad_x)
+  args : (Tensor x, Tensor min, Tensor max, Tensor grad_x_grad)
   output : Tensor(grad_out_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -415,7 +415,7 @@
 
 - backward_op : clipwithtensor_grad
   forward : clipwithtensor (Tensor x, Tensor min, Tensor max) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, Tensor min, Tensor)
+  args : (Tensor x, Tensor min, Tensor max, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml
index 1c30eaf81b8933..3be3937ed1ead1 100755
--- a/paddle/phi/ops/yaml/op_compat.yaml
+++ b/paddle/phi/ops/yaml/op_compat.yaml
@@ -602,20 +602,18 @@
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
 
-- op : clipwithtensor
-  backward : clipwithtensor_grad, clipwithtensor_double_grad
+- op : clip_by_norm
   inputs :
     x : X
-    min : Min
-    max : Max
   outputs :
     out : Out
-  extra :
-    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
 
-- op : clip_by_norm
+- op : clipwithtensor
+  backward : clipwithtensor_grad, clipwithtensor_double_grad
   inputs :
     x : X
+    min : Min
+    max : Max
   outputs :
     out : Out
 
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 20f4d0ee300d1f..ec12fbf13f7413 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -963,6 +963,17 @@
   backward : clip_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : clip_by_norm
+  args : (Tensor x, float max_norm)
+  output : Tensor(out)
+  infer_meta :
+    func : ClipByNormInferMeta
+  kernel :
+    func : clip_by_norm {dense -> dense}
+           clip_by_norm_sr {selected_rows -> selected_rows}
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
+  traits : paddle::dialect::ForwardOnlyTrait
+
 - op : clipwithtensor
   args : (Tensor x, Tensor min, Tensor max)
   output : Tensor(out)
@@ -976,17 +987,6 @@
   backward : clipwithtensor_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
-- op : clip_by_norm
-  args : (Tensor x, float max_norm)
-  output : Tensor(out)
-  infer_meta :
-    func : ClipByNormInferMeta
-  kernel :
-    func : clip_by_norm {dense -> dense}
-           clip_by_norm_sr {selected_rows -> selected_rows}
-  interfaces : paddle::dialect::InferSymbolicShapeInterface
-  traits : paddle::dialect::ForwardOnlyTrait
-
 - op : coalesce_tensor
   args : (Tensor[] input, DataType dtype, bool copy_data = false, bool set_constant = false, bool persist_output = false, float constant = 0.0, bool use_align = true, int align_size = -1, int size_of_dtype = -1, int64_t[] concated_shapes = {}, int64_t[] concated_ranks = {})
   output : Tensor[](output){input.size()}, Tensor(fused_output)
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index c2e96dad037b00..cf991456fc0da8 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -3831,7 +3831,7 @@ def clip(
             raise ValueError(
                 f"The max dimension should be equal to the inner dimension of the x, but the max dimension is {max.shape}"
             )
-
+
     if in_dynamic_or_pir_mode():
         return _C_ops.clipwithtensor(x, min, max)
     else: