Skip to content

Commit

Permalink
change name to clipmul
Browse files Browse the repository at this point in the history
  • Loading branch information
a162837 committed Nov 25, 2024
1 parent cd2738f commit 81b1c78
Show file tree
Hide file tree
Showing 17 changed files with 272 additions and 261 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4375,6 +4375,32 @@ bool Where_OpInferSymbolicShape(pir::Operation *op,
return WhereOpInferSymbolicShape(op, infer_context);
}

bool ClipTensorOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});

const std::vector<pir::Value> &operands = {
op->operand_source(0), op->operand_source(1), op->operand_source(2)};

size_t rank = x_shape.size();

for (size_t i = 0; i < rank; ++i) {
paddle::dialect::details::BuildCstrEqForTensorListAlongAxis(
infer_context, operands, i);
}

return true;
}

bool ClipTensor_OpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
return ClipTensorOpInferSymbolicShape(op, infer_context);
}

bool MultiplexOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &inputs_shape_or_data_list =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightOnlyLinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightedSampleNeighbors)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ClipTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ClipTensor_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeChannelWiseDequantizeMaxAbs)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MultiDot)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ OP_SAME_OPERANDS_AND_RESULT(Ceil_)
OP_SAME_OPERANDS_AND_RESULT(Celu)
OP_SAME_OPERANDS_AND_RESULT(Clip)
OP_SAME_OPERANDS_AND_RESULT(Clip_)
OP_SAME_OPERANDS_AND_RESULT(Clipmul_)
OP_SAME_OPERANDS_AND_RESULT(Clipmul_)
OP_SAME_OPERANDS_AND_RESULT(Conj)
OP_SAME_OPERANDS_AND_RESULT(CopyTo)
OP_SAME_OPERANDS_AND_RESULT(Cos)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Ceil_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Celu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Clip)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Clip_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Clipmul_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conj)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CopyTo)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cos)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/clip_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void ClipGradKernel(const Context& dev_ctx,
DenseTensor* x_grad);

template <typename T, typename Context>
void ClipMulGradKernel(const Context& dev_ctx,
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/kernels/clip_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,11 @@ void ClipMulKernel(const Context& dev_ctx,
const DenseTensor& max,
DenseTensor* out);

template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out);

} // namespace phi
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/clip_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace phi {

template <typename T, typename Context>
void ClipMulGradKernel(const Context& dev_ctx,
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
Expand Down Expand Up @@ -50,10 +50,10 @@ PD_REGISTER_KERNEL(clip_grad,
int,
int64_t) {}

PD_REGISTER_KERNEL(clipmul_grad,
PD_REGISTER_KERNEL(clip_tensor_grad,
CPU,
ALL_LAYOUT,
phi::ClipMulGradKernel,
phi::ClipTensorGradKernel,
float,
double,
int,
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/cpu/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace phi {

template <typename T, typename Context>
void ClipMulKernel(const Context& dev_ctx,
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
Expand All @@ -34,7 +34,8 @@ void ClipMulKernel(const Context& dev_ctx,
T* out_data = dev_ctx.template Alloc<T>(out);

for (int i = 0; i < x_numel; i++) {
out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x_data[i];
out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i];
out_data[i] = out_data[i] > max_data[i] ? max_data[i] : out_data[i];
}
}

Expand All @@ -44,4 +45,4 @@ PD_REGISTER_KERNEL(
clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}

PD_REGISTER_KERNEL(
clipmul, CPU, ALL_LAYOUT, phi::ClipMulKernel, float, double, int, int64_t) {}
clip_tensor, CPU, ALL_LAYOUT, phi::ClipTensorKernel, float, double, int, int64_t) {}
10 changes: 5 additions & 5 deletions paddle/phi/kernels/gpu/clip_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
namespace phi {

template <typename T>
__global__ void ClipMulGradFunctor(const int N, const T* out_grad, const T* x, const T* min, const T* max, T* x_grad) {
__global__ void ClipTensorGradFunctor(const int N, const T* out_grad, const T* x, const T* min, const T* max, T* x_grad) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
x_grad[idx] = (x[idx] > min[idx]) && (x[idx] < max[idx]) ? out_grad[idx] : static_cast<T>(0);
}
};

template <typename T, typename Context>
void ClipMulGradKernel(const Context& dev_ctx,
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
Expand All @@ -48,7 +48,7 @@ void ClipMulGradKernel(const Context& dev_ctx,

auto stream = dev_ctx.stream();
auto config = backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
ClipMulGradFunctor<T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
ClipTensorGradFunctor<T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, out_grad_data, x_data, min_data, max_data, x_grad_data);
}

Expand All @@ -64,10 +64,10 @@ PD_REGISTER_KERNEL(clip_grad,
phi::dtype::bfloat16,
phi::dtype::float16) {}

PD_REGISTER_KERNEL(clipmul_grad,
PD_REGISTER_KERNEL(clip_tensor_grad,
GPU,
ALL_LAYOUT,
phi::ClipMulGradKernel,
phi::ClipTensorGradKernel,
float,
double,
int,
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/kernels/gpu/clip_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
namespace phi {

template <typename T>
struct ClipMulFunctor {
struct ClipTensorFunctor {
inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const {
return x < min_ ? min_ : (x > max_ ? max_ : x);
}
};

template <typename T, typename Context>
void ClipMulKernel(const Context& dev_ctx,
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
Expand All @@ -41,8 +41,8 @@ void ClipMulKernel(const Context& dev_ctx,
std::vector<DenseTensor*> outs = {out};
dev_ctx.template Alloc<T>(out);

ClipMulFunctor<T> func;
funcs::ElementwiseKernel<T, ClipMulFunctor<T>, 1>(dev_ctx, ins, &outs, func);
ClipTensorFunctor<T> func;
funcs::ElementwiseKernel<T, ClipTensorFunctor<T>, 1>(dev_ctx, ins, &outs, func);
}

} // namespace phi
Expand All @@ -58,10 +58,10 @@ PD_REGISTER_KERNEL(clip,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(clipmul,
PD_REGISTER_KERNEL(clip_tensor,
GPU,
ALL_LAYOUT,
phi::ClipMulKernel,
phi::ClipTensorKernel,
float,
double,
int,
Expand Down
7 changes: 3 additions & 4 deletions paddle/phi/kernels/xpu/clip_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ void ClipGradKernel(const Context& ctx,
}

template <typename T, typename Context>
void ClipMulGradKernel(const Context& dev_ctx,
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
using XPUDataType = typename XPUTypeTrait<T>::Type;

DenseTensor min_tensor(phi::DataType::BOOL);
DenseTensor max_tensor(phi::DataType::BOOL);
Expand All @@ -75,10 +74,10 @@ PD_REGISTER_KERNEL(clip_grad,
int64_t,
int) {}

PD_REGISTER_KERNEL(clipmul_grad,
PD_REGISTER_KERNEL(clip_tensor_grad,
XPU,
ALL_LAYOUT,
phi::ClipMulGradKernel,
phi::ClipTensorGradKernel,
float,
phi::dtype::float16,
int64_t,
Expand Down
14 changes: 7 additions & 7 deletions paddle/phi/kernels/xpu/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void ClipKernel(const Context& dev_ctx,
}

template <typename T, typename Context>
void ClipMulKernel(const Context& dev_ctx,
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
Expand All @@ -61,7 +61,7 @@ void ClipMulKernel(const Context& dev_ctx,
const XPUDataType* min_data = reinterpret_cast<const XPUDataType*>(min.data<T>());
const XPUDataType* max_data = reinterpret_cast<const XPUDataType*>(max.data<T>());
XPUDataType* out_data = reinterpret_cast<XPUDataType*>(dev_ctx.template Alloc<T>(out));

auto min_dims = common::vectorize<int>(min.dims());
if (min_dims.size() == 0) {
min_dims = std::vector<int>({1});
Expand All @@ -70,7 +70,7 @@ void ClipMulKernel(const Context& dev_ctx,
if (max_dims.size() == 0) {
max_dims = std::vector<int>({1});
}

DenseTensor min_tensor(phi::DataType::BOOL);
LessThanKernel<T, Context>(dev_ctx, x, min, &min_tensor);

Expand All @@ -97,7 +97,7 @@ void ClipMulKernel(const Context& dev_ctx,
int ret2 = xpu::select(
dev_ctx.x_context(), max_tensor_data, max_data, x_data, out_data, max_tensor_dims, max_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret2, "xpu::select");

}

} // namespace phi
Expand All @@ -112,12 +112,12 @@ PD_REGISTER_KERNEL(clip,
int64_t,
int) {}

PD_REGISTER_KERNEL(clipmul,
PD_REGISTER_KERNEL(clip_tensor,
XPU,
ALL_LAYOUT,
phi::ClipMulKernel,
phi::ClipTensorKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int) {}
int) {}
26 changes: 13 additions & 13 deletions paddle/phi/ops/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -402,28 +402,28 @@
backward : clip_double_grad
inplace : (out_grad -> x_grad)

- backward_op : clipmul_double_grad
forward : clipmul_grad (Tensor x, Tensor min, Tensor max, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor min, Tensor max, Tensor grad_x_grad)
output : Tensor(grad_out_grad)
- backward_op : clip_tensor_grad
forward : clip_tensor (Tensor x, Tensor min, Tensor max) -> Tensor(out)
args : (Tensor x, Tensor min, Tensor max, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : clipmul_grad
data_type : x
func : clip_tensor_grad
backward : clip_tensor_double_grad
inplace : (out_grad -> x_grad)

- backward_op : clipmul_grad
forward : clipmul (Tensor x, Tensor min, Tensor max) -> Tensor(out)
args : (Tensor x, Tensor min, Tensor max, Tensor out_grad)
output : Tensor(x_grad)
- backward_op : clip_tensor_double_grad
forward : clip_tensor_grad (Tensor x, Tensor min, Tensor max, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor min, Tensor max, Tensor grad_x_grad)
output : Tensor(grad_out_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : clipmul_grad
backward : clipmul_double_grad
inplace : (out_grad -> x_grad)
func : clip_tensor_grad
data_type : x

- backward_op : complex_grad
forward : complex (Tensor real, Tensor imag) -> Tensor(out)
Expand Down
9 changes: 0 additions & 9 deletions paddle/phi/ops/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -608,15 +608,6 @@
outputs :
out : Out

- op : clipmul
backward : clipmul_grad, clipmul_double_grad
inputs :
x : X
min : Min
max : Max
outputs :
out : Out

- op : coalesce_tensor
inputs :
{input : Input}
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -975,17 +975,17 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface
traits : paddle::dialect::ForwardOnlyTrait

- op : clipmul
- op : clip_tensor
args : (Tensor x, Tensor min, Tensor max)
output : Tensor(out)
inplace : (x -> out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : clipmul
func : clip_tensor
data_type : x
backward : clipmul_grad
inplace : (x -> out)
backward : clip_tensor_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : coalesce_tensor
Expand Down
Loading

0 comments on commit 81b1c78

Please sign in to comment.