Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.28】为 paddle.clip 进行功能增强 #69269

Open
wants to merge 53 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
23adff1
test
a162837 Nov 10, 2024
5481d22
Merge branch 'PaddlePaddle:develop' into TensorClip
a162837 Nov 10, 2024
b3f24f8
add cpu and gpu
a162837 Nov 11, 2024
fe7a239
Merge branch 'TensorClip' of https://github.com/a162837/Paddle into T…
a162837 Nov 11, 2024
ade40cc
delete min compare with max
a162837 Nov 11, 2024
eb702fa
Merge branch 'PaddlePaddle:develop' into TensorClip
a162837 Nov 12, 2024
491700a
change name to clipmul
a162837 Nov 13, 2024
aadd345
Merge branch 'TensorClip' of https://github.com/a162837/Paddle into T…
a162837 Nov 13, 2024
8ad4626
change name to clipmul
a162837 Nov 13, 2024
ccf7347
change name to clipmul
a162837 Nov 13, 2024
78f72ba
change name to clipmul
a162837 Nov 13, 2024
6a806f2
Merge branch 'PaddlePaddle:develop' into TensorClip
a162837 Nov 13, 2024
8b31397
change name to clipmul
a162837 Nov 13, 2024
cd2738f
Merge branch 'TensorClip' of https://github.com/a162837/Paddle into T…
a162837 Nov 13, 2024
719307b
change name to clipmul
a162837 Nov 14, 2024
808139c
fix codestyle
a162837 Dec 3, 2024
ed4ca4a
add test
a162837 Dec 3, 2024
f5d638a
add c++
a162837 Dec 4, 2024
24cf28c
fix codestyle
a162837 Dec 5, 2024
ab6a032
fix codestyle
a162837 Dec 5, 2024
574e703
add test
a162837 Dec 6, 2024
b5400e3
fix bug
a162837 Dec 6, 2024
e9ecc08
change name to clipmul
a162837 Nov 14, 2024
86ed3ad
Merge branch 'TensorClip' of https://github.com/a162837/Paddle into T…
a162837 Dec 6, 2024
d2d28ee
Merge branch 'PaddlePaddle:develop' into TensorClip
a162837 Dec 6, 2024
c8e5ba5
fix bug
a162837 Dec 6, 2024
93b84fe
Merge branch 'TensorClip' of https://github.com/a162837/Paddle into T…
a162837 Dec 6, 2024
9d23564
add
a162837 Dec 6, 2024
39b2429
add
a162837 Dec 7, 2024
2e716c3
add
a162837 Dec 7, 2024
06f3562
add
a162837 Dec 7, 2024
cc7b1ce
add
a162837 Dec 7, 2024
509d25f
add
a162837 Dec 7, 2024
6b73b98
add
a162837 Dec 7, 2024
0ecfe86
add
a162837 Dec 7, 2024
d4d6d02
add
a162837 Dec 8, 2024
32c4bf8
add
a162837 Dec 8, 2024
513b21f
add
a162837 Dec 8, 2024
8d784f3
add
a162837 Dec 8, 2024
e8d1f84
add
a162837 Dec 8, 2024
6d816fa
add
a162837 Dec 8, 2024
e3f16ed
add
a162837 Dec 8, 2024
b42f572
add
a162837 Dec 9, 2024
3cbb875
add
a162837 Dec 9, 2024
4d349a7
add
a162837 Dec 10, 2024
ed6a94b
add
a162837 Dec 6, 2024
2b10e2d
add
a162837 Dec 10, 2024
ec95360
add
a162837 Dec 10, 2024
75201df
add cpu gpu xpu
a162837 Dec 25, 2024
fd66903
Merge branch 'develop' into TensorClip
a162837 Dec 25, 2024
32de0d4
add cpu gpu xpu
a162837 Dec 25, 2024
1556896
Merge branch 'PaddlePaddle:develop' into TensorClip
a162837 Dec 30, 2024
b2059ec
add python test
a162837 Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions paddle/phi/kernels/clip_tensor_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad);

} // namespace phi
29 changes: 29 additions & 0 deletions paddle/phi/kernels/clip_tensor_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out);

} // namespace phi
56 changes: 56 additions & 0 deletions paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_tensor_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());

const T* x_data = x.data<T>();
const T* min_data = tem_min.data<T>();
const T* max_data = tem_max.data<T>();
auto numel = x.numel();
auto* dout = out_grad.data<T>();

auto* dx = dev_ctx.template Alloc<T>(x_grad);
for (int i = 0; i < numel; i++) {
dx[i] = (x_data[i] > min_data[i] && x_data[i] < max_data[i])
? dout[i]
: static_cast<T>(0);
}
}

} // namespace phi

PD_REGISTER_KERNEL(clip_tensor_grad,
CPU,
ALL_LAYOUT,
phi::ClipTensorGradKernel,
float,
double,
int,
int64_t) {}
56 changes: 56 additions & 0 deletions paddle/phi/kernels/cpu/clip_tensor_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_tensor_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());

const T* x_data = x.data<T>();
const T* min_data = tem_min.data<T>();
const T* max_data = tem_max.data<T>();

auto x_numel = x.numel();

T* out_data = dev_ctx.template Alloc<T>(out);

for (int i = 0; i < x_numel; i++) {
out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i];
out_data[i] = out_data[i] > max_data[i] ? max_data[i] : out_data[i];
}
}

} // namespace phi

PD_REGISTER_KERNEL(clip_tensor,
CPU,
ALL_LAYOUT,
phi::ClipTensorKernel,
float,
double,
int,
int64_t) {}
76 changes: 76 additions & 0 deletions paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_tensor_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace phi {

template <typename T>
__global__ void ClipTensorGradFunctor(const int N,
const T* out_grad,
const T* x,
const T* min,
const T* max,
T* x_grad) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
x_grad[idx] = (x[idx] > min[idx]) && (x[idx] < max[idx])
? out_grad[idx]
: static_cast<T>(0);
}
}

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());

const T* x_data = x.data<T>();
auto numel = x.numel();
const T* min_data = tem_min.data<T>();
const T* max_data = tem_max.data<T>();
const T* out_grad_data = out_grad.data<T>();

T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);

auto stream = dev_ctx.stream();
auto config = backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
ClipTensorGradFunctor<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, out_grad_data, x_data, min_data, max_data, x_grad_data);
}

} // namespace phi

PD_REGISTER_KERNEL(clip_tensor_grad,
GPU,
ALL_LAYOUT,
phi::ClipTensorGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
65 changes: 65 additions & 0 deletions paddle/phi/kernels/gpu/clip_tensor_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_tensor_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace phi {

template <typename T>
struct ClipTensorFunctor {
inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const {
T x_ = x < min_ ? min_ : x;
T x__ = x_ > max_ ? max_ : x_;
return x__;
}
};

template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());

std::vector<const DenseTensor*> ins = {&x, &tem_min, &tem_max};
std::vector<DenseTensor*> outs = {out};
dev_ctx.template Alloc<T>(out);

ClipTensorFunctor<T> func;
funcs::ElementwiseKernel<T, ClipTensorFunctor<T>, 1>(
dev_ctx, ins, &outs, func);
}

} // namespace phi

PD_REGISTER_KERNEL(clip_tensor,
GPU,
ALL_LAYOUT,
phi::ClipTensorKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
77 changes: 77 additions & 0 deletions paddle/phi/kernels/xpu/clip_tensor_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_tensor_grad_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/logical_kernel.h"
#include "paddle/phi/kernels/where_kernel.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
DenseTensor ex_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
DenseTensor ex_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());

phi::DenseTensor x_ls_min;
MetaTensor meta_x_ls_min(&x_ls_min);
UnchangedExceptDtypeInferMeta(x, &meta_x_ls_min);
meta_x_ls_min.set_dtype(phi::DataType::BOOL);
phi::LessThanKernel<T, Context>(dev_ctx, ex_min, x, &x_ls_min);

phi::DenseTensor x_ls_max;
MetaTensor meta_x_ls_max(&x_ls_max);
UnchangedExceptDtypeInferMeta(x, &meta_x_ls_max);
meta_x_ls_max.set_dtype(phi::DataType::BOOL);
phi::LessThanKernel<T, Context>(dev_ctx, x, ex_max, &x_ls_max);

phi::DenseTensor out;
MetaTensor meta_out(&out);
UnchangedExceptDtypeInferMeta(x, &meta_out);
meta_out.set_dtype(phi::DataType::BOOL);
phi::LogicalAndKernel<bool, Context>(dev_ctx, x_ls_min, x_ls_max, &out);

phi::DenseTensor zero_tensor;
MetaTensor meta_zero(&zero_tensor);
UnchangedInferMeta(x_grad, &meta_zero);
phi::FullKernel<T, Context>(dev_ctx,
common::vectorize(x_grad->dims()),
0.0f,
zero_tensor.dtype(),
&zero_tensor);
phi::WhereKernel<T, Context>(dev_ctx, out, out_grad, zero_tensor, x_grad);
}

} // namespace phi

PD_REGISTER_KERNEL(clip_tensor_grad,
XPU,
ALL_LAYOUT,
phi::ClipTensorGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int) {}
Loading
Loading