delete min compare with max
a162837 committed Nov 12, 2024
1 parent fe7a239 commit ade40cc
Showing 11 changed files with 154 additions and 53 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/clip_grad_kernel.h
@@ -31,8 +31,8 @@ void ClipGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ClipWithTensorGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
-                              const DenseTensor& out_grad,
                               const DenseTensor& min,
                               const DenseTensor& max,
+                              const DenseTensor& out_grad,
                               DenseTensor* x_grad);
 }  // namespace phi
9 changes: 4 additions & 5 deletions paddle/phi/kernels/cpu/clip_grad_kernel.cc
@@ -21,20 +21,19 @@
 namespace phi {
 
 template <typename T, typename Context>
-void ClipWithTensorGradKernel(const Context& ctx,
-                              const Context& dev_ctx,
+void ClipWithTensorGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
-                              const DenseTensor& out_grad,
                               const DenseTensor& min,
                               const DenseTensor& max,
+                              const DenseTensor& out_grad,
                               DenseTensor* x_grad) {
   const T* x_data = x.data<T>();
   const T* min_data = min.data<T>();
   const T* max_data = max.data<T>();
   auto numel = x.numel();
   auto* dout = out_grad.data<T>();
 
-  auto* dx = ctx.template Alloc<T>(x_grad);
+  auto* dx = dev_ctx.template Alloc<T>(x_grad);
   for (int i = 0; i < numel; i++) {
     dx[i] = (x_data[i] > min_data[i] && x_data[i] < max_data[i]) ? dout[i] : static_cast<T>(0);
   }
@@ -58,4 +57,4 @@ PD_REGISTER_KERNEL(clipwithtensor_grad,
float,
double,
int,
int64_t) {}
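
Note: the backward rule implemented above is the usual clip gradient — the upstream gradient flows through only where x lies strictly inside (min, max); elements at or outside the bounds get zero gradient. A minimal standalone sketch of that rule, assuming equal-length std::vector inputs (names and types here are illustrative, not the phi API):

    #include <vector>

    // Pass out_grad through where min < x < max; zero the gradient elsewhere.
    template <typename T>
    std::vector<T> ClipWithTensorGradReference(const std::vector<T>& x,
                                               const std::vector<T>& min,
                                               const std::vector<T>& max,
                                               const std::vector<T>& out_grad) {
      std::vector<T> x_grad(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        x_grad[i] = (x[i] > min[i] && x[i] < max[i]) ? out_grad[i] : static_cast<T>(0);
      }
      return x_grad;
    }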
14 changes: 5 additions & 9 deletions paddle/phi/kernels/cpu/clip_kernel.cc
@@ -21,24 +21,20 @@
 namespace phi {
 
 template <typename T, typename Context>
-void ClipWithTensorKernel(const Context& ctx,
+void ClipWithTensorKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& min,
                           const DenseTensor& max,
                           DenseTensor* out) {
-  const T* x_data = x.data<bool>();
+  const T* x_data = x.data<T>();
   const T* min_data = min.data<T>();
   const T* max_data = max.data<T>();
   auto x_numel = x.numel();
 
-  T* out_data = ctx.template Alloc<T>(out);
+  T* out_data = dev_ctx.template Alloc<T>(out);
 
   for (int i = 0; i < x_numel; i++) {
-    PADDLE_ENFORCE_LE(
-        min_data[i],
-        max_data[i],
-        errors::InvalidArgument("max should be greater than or equal to min. "));
-    out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x;
+    out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x_data[i];
   }
 }

@@ -48,4 +44,4 @@ PD_REGISTER_KERNEL(
clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}

PD_REGISTER_KERNEL(
clipwithtensor, CPU, ALL_LAYOUT, phi::ClipWithTensorKernel, float, double, int, int64_t) {}
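
Note: the substantive change in this file (and the source of the commit title) is dropping the per-element PADDLE_ENFORCE_LE(min, max) check: the kernel now clamps unconditionally instead of raising InvalidArgument when an element of min exceeds the corresponding element of max, and the stray `x` at the end of the ternary is corrected to `x_data[i]`. A minimal sketch of the resulting forward semantics, assuming equal-length std::vector inputs (illustrative only, not the phi API):

    #include <vector>

    // Element-wise clamp of x[i] to [min[i], max[i]] with no min <= max validation;
    // the branches are evaluated in order, so no error is raised when min[i] > max[i].
    template <typename T>
    std::vector<T> ClipWithTensorReference(const std::vector<T>& x,
                                           const std::vector<T>& min,
                                           const std::vector<T>& max) {
      std::vector<T> out(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        out[i] = x[i] < min[i] ? min[i] : (x[i] > max[i] ? max[i] : x[i]);
      }
      return out;
    }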
28 changes: 18 additions & 10 deletions paddle/phi/kernels/gpu/clip_grad_kernel.cu
@@ -23,25 +23,33 @@
 namespace phi {
 
 template <typename T>
-class ClipWithTensorGradFunctor {
-  HOSTDEVICE T operator()(const T x, const T y, const T min_, const max_) const {
-    return (y > min_ && y < max_) ? x : static_cast<T>(0);
+__global__ void ClipWithTensorGradFunctor(const int N, const T* out_grad, const T* x, const T* min, const T* max, T* x_grad) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < N; idx += blockDim.x * gridDim.x) {
+    x_grad[idx] = (x[idx] > min[idx]) && (x[idx] < max[idx]) ? out_grad[idx] : static_cast<T>(0);
   }
 };
 
 template <typename T, typename Context>
 void ClipWithTensorGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
-                              const DenseTensor& out_grad,
                               const DenseTensor& min,
                               const DenseTensor& max,
+                              const DenseTensor& out_grad,
                               DenseTensor* x_grad) {
 
-  std::vector<const DenseTensor*> ins = {&out_grad, &x, &min, &max};
-  std::vector<DenseTensor*> outs = {x_grad};
-  auto functor = ClipWithTensorGradFunctor<T>();
-  dev_ctx.template Alloc<T>(x_grad);
-  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+  const T* x_data = x.data<T>();
+  auto numel = x.numel();
+  const T* min_data = min.data<T>();
+  const T* max_data = max.data<T>();
+  const T* out_grad_data = out_grad.data<T>();
+
+  T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+
+  auto stream = dev_ctx.stream();
+  auto config = backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
+  ClipWithTensorGradFunctor<T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
+      numel, out_grad_data, x_data, min_data, max_data, x_grad_data);
 }
 
 }
@@ -65,4 +73,4 @@ PD_REGISTER_KERNEL(clipwithtensor_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
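
Note: the GPU backward now launches a grid-stride CUDA kernel directly instead of going through the ElementwiseKernel functor, with the launch shape taken from GetGpuLaunchConfig1D. The host-side arithmetic behind such a 1-D launch is sketched below in plain C++; the thread count and grid cap are assumptions for illustration, not the values GetGpuLaunchConfig1D actually computes:

    #include <algorithm>
    #include <cstdint>

    struct LaunchShape {
      int threads_per_block;
      int blocks_per_grid;
    };

    // Illustrative 1-D launch shape for a grid-stride kernel: cap the grid and
    // let each thread stride by blockDim.x * gridDim.x until it covers numel.
    inline LaunchShape Make1DLaunchShape(int64_t numel) {
      const int threads_per_block = 256;   // assumed, not Paddle's choice
      const int64_t max_blocks = 65535;    // assumed grid cap
      int64_t blocks = (numel + threads_per_block - 1) / threads_per_block;
      blocks = std::min<int64_t>(std::max<int64_t>(blocks, int64_t{1}), max_blocks);
      return {threads_per_block, static_cast<int>(blocks)};
    }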
14 changes: 8 additions & 6 deletions paddle/phi/kernels/gpu/clip_kernel.cu
@@ -14,33 +14,35 @@

#include "paddle/phi/kernels/clip_kernel.h"

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace phi {

 template <typename T>
 struct ClipWithTensorFunctor {
-  inline HOSTDEVICE T operator()(const bool x, const T min_, const T max_) const {
-    return x < min_ ? min_ : x > max_ ? max_ : x;
+  inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const {
+    return x < min_ ? min_ : (x > max_ ? max_ : x);
   }
 };
 
 template <typename T, typename Context>
-void ClipWithTensorKernel(const Context& ctx,
+void ClipWithTensorKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& min,
                           const DenseTensor& max,
                           DenseTensor* out) {
   std::vector<const DenseTensor*> ins = {&x, &min, &max};
   std::vector<DenseTensor*> outs = {out};
-  ctx.template Alloc<T>(out);
+  dev_ctx.template Alloc<T>(out);
 
   ClipWithTensorFunctor<T> func;
-  funcs::ElementwiseKernel<T, ClipWithTensorFunctor<T>, 1>(ctx, ins, &outs, func);
+  funcs::ElementwiseKernel<T, ClipWithTensorFunctor<T>, 1>(dev_ctx, ins, &outs, func);
 }
 
 } // namespace phi
@@ -65,4 +67,4 @@ PD_REGISTER_KERNEL(clipwithtensor,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
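
Note: the functional edits to the functor are the type of x (bool → T) and the added parentheses in the nested conditional; the parentheses only aid readability, since ?: is right-associative and both spellings parse the same way. A tiny self-contained check in plain C++, purely illustrative:

    // ?: is right-associative, so a ? b : c ? d : e parses as a ? b : (c ? d : e).
    constexpr int ClipNoParens(int x, int lo, int hi) {
      return x < lo ? lo : x > hi ? hi : x;
    }
    constexpr int ClipWithParens(int x, int lo, int hi) {
      return x < lo ? lo : (x > hi ? hi : x);
    }
    static_assert(ClipNoParens(5, 1, 3) == ClipWithParens(5, 1, 3), "same parse");
    static_assert(ClipNoParens(0, 1, 3) == ClipWithParens(0, 1, 3), "same parse");
    static_assert(ClipNoParens(2, 1, 3) == ClipWithParens(2, 1, 3), "same parse");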
35 changes: 35 additions & 0 deletions paddle/phi/kernels/xpu/clip_grad_kernel.cc
@@ -14,8 +14,13 @@

#include "paddle/phi/kernels/clip_grad_kernel.h"

#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/where_kernel.h"

namespace phi {

@@ -38,6 +43,27 @@ void ClipGradKernel(const Context& ctx,
static_cast<XPUDataType>(max.to<T>()));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
}

template <typename T, typename Context>
void ClipWithTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
using XPUDataType = typename XPUTypeTrait<T>::Type;

DenseTensor min_tensor(phi::DataType::BOOL);
DenseTensor max_tensor(phi::DataType::BOOL);
LessThanKernel<T, Context>(dev_ctx, min, x, &min_tensor);
LessThanKernel<T, Context>(dev_ctx, x, max, &max_tensor);
DenseTensor out(phi::DataType::BOOL);
EqualKernel<T, Context>(dev_ctx, min_tensor, max_tensor, &out);
DenseTensor zero_tensor(x_grad->dtype());
FullKernel<T, Context>(dev_ctx, common::vectorize(x_grad->dims()), 0.0f, zero_tensor.dtype(), &zero_tensor);
WhereKernel<T, Context>(dev_ctx, out, out_grad, zero_tensor, x_grad);
}
} // namespace phi

PD_REGISTER_KERNEL(clip_grad,
@@ -48,3 +74,12 @@ PD_REGISTER_KERNEL(clip_grad,
phi::dtype::float16,
int64_t,
int) {}

PD_REGISTER_KERNEL(clipwithtensor_grad,
XPU,
ALL_LAYOUT,
phi::ClipWithTensorGradKernel,
float,
phi::dtype::float16,
int64_t,
int) {}
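
Note: rather than a dedicated XPU primitive, the backward above is composed from existing phi kernels: two LessThan comparisons produce boolean masks, the masks are combined, and WhereKernel selects between out_grad and a zero tensor. The standard clip gradient this composition targets — pass the gradient through on the interior of the clip range and zero it elsewhere — looks like this in plain C++ (illustrative names, not the phi/XPU API):

    #include <vector>

    // mask[i] = (min[i] < x[i]) && (x[i] < max[i]);  x_grad = Where(mask, out_grad, 0)
    template <typename T>
    std::vector<T> MaskedClipGradReference(const std::vector<T>& x,
                                           const std::vector<T>& min,
                                           const std::vector<T>& max,
                                           const std::vector<T>& out_grad) {
      std::vector<bool> mask(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        mask[i] = (min[i] < x[i]) && (x[i] < max[i]);  // interior of the clip range
      }
      std::vector<T> x_grad(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        x_grad[i] = mask[i] ? out_grad[i] : static_cast<T>(0);  // the Where step
      }
      return x_grad;
    }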
63 changes: 63 additions & 0 deletions paddle/phi/kernels/xpu/clip_kernel.cc
@@ -17,8 +17,11 @@
#include "glog/logging.h"

#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/where_kernel.h"

namespace phi {

@@ -47,6 +50,56 @@ void ClipKernel(const Context& dev_ctx,
XPUAPIErrorMsg[r]));
}

template <typename T, typename Context>
void ClipWithTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
using XPUDataType = typename XPUTypeTrait<T>::Type;
const XPUDataType* x_data = reinterpret_cast<const XPUDataType*>(x.data<T>());
const XPUDataType* min_data = reinterpret_cast<const XPUDataType*>(min.data<T>());
const XPUDataType* max_data = reinterpret_cast<const XPUDataType*>(max.data<T>());
XPUDataType* out_data = reinterpret_cast<XPUDataType*>(dev_ctx.template Alloc<T>(out));

auto min_dims = common::vectorize<int>(min.dims());
if (min_dims.size() == 0) {
min_dims = std::vector<int>({1});
}
auto max_dims = common::vectorize<int>(max.dims());
if (max_dims.size() == 0) {
max_dims = std::vector<int>({1});
}

DenseTensor min_tensor(phi::DataType::BOOL);
LessThanKernel<T, Context>(dev_ctx, x, min, &min_tensor);

auto min_tensor_dims = common::vectorize<int>(min_tensor.dims());
if (min_tensor_dims.size() == 0) {
min_tensor_dims = std::vector<int>({1});
}

const bool* min_tensor_data = min_tensor.data<bool>();
int ret = xpu::select(
dev_ctx.x_context(), min_tensor_data, min_data, x_data, out_data, min_tensor_dims, min_dims);

PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");

DenseTensor max_tensor(phi::DataType::BOOL);
LessThanKernel<T, Context>(dev_ctx, max, x, &max_tensor);

auto max_tensor_dims = common::vectorize<int>(max_tensor.dims());
if (max_tensor_dims.size() == 0) {
max_tensor_dims = std::vector<int>({1});
}

const bool* max_tensor_data = max_tensor.data<bool>();
int ret2 = xpu::select(
dev_ctx.x_context(), max_tensor_data, max_data, x_data, out_data, max_tensor_dims, max_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret2, "xpu::select");

}

} // namespace phi

PD_REGISTER_KERNEL(clip,
@@ -58,3 +111,13 @@ PD_REGISTER_KERNEL(clip,
phi::dtype::bfloat16,
int64_t,
int) {}

PD_REGISTER_KERNEL(clipwithtensor,
XPU,
ALL_LAYOUT,
phi::ClipWithTensorKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int) {}
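
Note: the XPU forward is built from two xpu::select calls rather than a fused clamp: the first select applies the lower bound (x < min picks min), the second applies the upper bound (max < x picks max). For reference, a sequential two-select clamp in plain C++ looks like this, with the second pass reading the result of the first so that the two bounds compose (names illustrative):

    #include <vector>

    // Two-pass clamp: pass 1 applies the lower bound, pass 2 applies the upper
    // bound to the output of pass 1.
    template <typename T>
    std::vector<T> TwoPassClampReference(const std::vector<T>& x,
                                         const std::vector<T>& min,
                                         const std::vector<T>& max) {
      std::vector<T> out(x.size());
      for (size_t i = 0; i < x.size(); ++i) {   // select(x < min, min, x)
        out[i] = x[i] < min[i] ? min[i] : x[i];
      }
      for (size_t i = 0; i < x.size(); ++i) {   // select(max < out, max, out)
        out[i] = max[i] < out[i] ? max[i] : out[i];
      }
      return out;
    }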
6 changes: 3 additions & 3 deletions paddle/phi/ops/yaml/backward.yaml
@@ -403,8 +403,8 @@
   inplace : (out_grad -> x_grad)
 
 - backward_op : clipwithtensor_double_grad
-  forward : clipwithtensor_grad (Tensor x, Tensor grad_out, Tensor min, Tensor max) -> Tensor(grad_x)
-  args : (Tensor x, Tensor grad_x_grad, Tensor min, Tensor max)
+  forward : clipwithtensor_grad (Tensor x, Tensor min, Tensor max, Tensor grad_out) -> Tensor(grad_x)
+  args : (Tensor x, Tensor min, Tensor max, Tensor grad_x_grad)
   output : Tensor(grad_out_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -415,7 +415,7 @@
 
 - backward_op : clipwithtensor_grad
   forward : clipwithtensor (Tensor x, Tensor min, Tensor max) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, Tensor min, Tensor)
+  args : (Tensor x, Tensor min, Tensor max, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
12 changes: 5 additions & 7 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -602,20 +602,18 @@
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
 
-- op : clipwithtensor
-  backward : clipwithtensor_grad, clipwithtensor_double_grad
+- op : clip_by_norm
   inputs :
     x : X
-    min : Min
-    max : Max
   outputs :
     out : Out
-  extra :
-    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
 
-- op : clip_by_norm
+- op : clipwithtensor
+  backward : clipwithtensor_grad, clipwithtensor_double_grad
   inputs :
     x : X
+    min : Min
+    max : Max
   outputs :
     out : Out

22 changes: 11 additions & 11 deletions paddle/phi/ops/yaml/ops.yaml
@@ -963,6 +963,17 @@
   backward : clip_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : clip_by_norm
+  args : (Tensor x, float max_norm)
+  output : Tensor(out)
+  infer_meta :
+    func : ClipByNormInferMeta
+  kernel :
+    func : clip_by_norm {dense -> dense}
+           clip_by_norm_sr {selected_rows -> selected_rows}
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
+  traits : paddle::dialect::ForwardOnlyTrait
+
 - op : clipwithtensor
   args : (Tensor x, Tensor min, Tensor max)
   output : Tensor(out)
@@ -976,17 +987,6 @@
   backward : clipwithtensor_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
-- op : clip_by_norm
-  args : (Tensor x, float max_norm)
-  output : Tensor(out)
-  infer_meta :
-    func : ClipByNormInferMeta
-  kernel :
-    func : clip_by_norm {dense -> dense}
-           clip_by_norm_sr {selected_rows -> selected_rows}
-  interfaces : paddle::dialect::InferSymbolicShapeInterface
-  traits : paddle::dialect::ForwardOnlyTrait
-
 - op : coalesce_tensor
   args : (Tensor[] input, DataType dtype, bool copy_data = false, bool set_constant = false, bool persist_output = false, float constant = 0.0, bool use_align = true, int align_size = -1, int size_of_dtype = -1, int64_t[] concated_shapes = {}, int64_t[] concated_ranks = {})
   output : Tensor[](output){input.size()}, Tensor(fused_output)