Commit

add

a162837 committed Dec 25, 2024
1 parent 2b10e2d commit ec95360
Showing 31 changed files with 825 additions and 632 deletions.
36 changes: 36 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -133,6 +133,42 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
  x_grad->set_dtype(out_grad.dtype());
}

void ClipTensorGradInferMeta(const MetaTensor& x,
                             const MetaTensor& min,
                             const MetaTensor& max,
                             const MetaTensor& out_grad,
                             MetaTensor* x_grad) {
  auto x_dims = x.dims();
  auto min_dims = min.dims();
  auto max_dims = max.dims();

  if (common::product(x_dims) >= common::product(min_dims) &&
      common::product(x_dims) >= common::product(max_dims)) {
    PADDLE_ENFORCE_EQ(
        out_grad.dims(),
        x.dims(),
        errors::InvalidArgument(
            "The gradient and the expanded input should have the same "
            "shape."));
    x_grad->set_dims(x.dims());
  } else if (common::product(min_dims) >= common::product(x_dims) &&
             common::product(min_dims) >= common::product(max_dims)) {
    PADDLE_ENFORCE_EQ(
        out_grad.dims(),
        min.dims(),
        errors::InvalidArgument(
            "The gradient and the expanded input should have the same "
            "shape."));
    x_grad->set_dims(min.dims());
  } else {
    PADDLE_ENFORCE_EQ(
        out_grad.dims(),
        max.dims(),
        errors::InvalidArgument(
            "The gradient and the expanded input should have the same "
            "shape."));
    x_grad->set_dims(max.dims());
  }
  x_grad->set_dtype(x.dtype());
}

void ComplexGradInferMeta(const MetaTensor& x,
                          const MetaTensor& y,
                          const MetaTensor& dout,
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -64,6 +64,12 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                 const std::string& data_format,
                                 MetaTensor* x_grad);

void ClipTensorGradInferMeta(const MetaTensor& x,
                             const MetaTensor& min,
                             const MetaTensor& max,
                             const MetaTensor& out_grad,
                             MetaTensor* x_grad);

void ComplexGradInferMeta(const MetaTensor& x,
                          const MetaTensor& y,
                          const MetaTensor& dout,
21 changes: 21 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -373,6 +373,27 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
  output_box->set_dtype(target_box.dtype());
}

void ClipTensorInferMeta(const MetaTensor& x,
                         const MetaTensor& min,
                         const MetaTensor& max,
                         MetaTensor* out) {
  auto x_dims = x.dims();
  auto min_dims = min.dims();
  auto max_dims = max.dims();

  if (common::product(x_dims) >= common::product(min_dims) &&
      common::product(x_dims) >= common::product(max_dims)) {
    out->set_dims(x.dims());
  } else if (common::product(min_dims) >= common::product(x_dims) &&
             common::product(min_dims) >= common::product(max_dims)) {
    out->set_dims(min.dims());
  } else {
    out->set_dims(max.dims());
  }
  out->set_dtype(x.dtype());
}

void DistributedPushSparseInferMeta(
    const std::vector<const MetaTensor*>& ids,
    const std::vector<const MetaTensor*>& shows,
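ClipTensorInferMeta above picks, among x, min, and max, the operand with the greatest element count to supply the output shape. A minimal standalone sketch of that dominance rule, outside Paddle (Product and OutputShape are hypothetical names, not part of this commit):

#include <cstdint>
#include <vector>

// Element count of a shape, mirroring common::product().
int64_t Product(const std::vector<int64_t>& dims) {
  int64_t p = 1;
  for (int64_t d : dims) p *= d;
  return p;
}

// Largest-numel operand wins; ties resolve toward x, then min,
// matching the order of the if / else-if chain above.
std::vector<int64_t> OutputShape(const std::vector<int64_t>& x,
                                 const std::vector<int64_t>& mn,
                                 const std::vector<int64_t>& mx) {
  if (Product(x) >= Product(mn) && Product(x) >= Product(mx)) return x;
  if (Product(mn) >= Product(x) && Product(mn) >= Product(mx)) return mn;
  return mx;
}

For example, with x of shape {2, 3} and scalar-like min and max of shape {1}, OutputShape returns {2, 3}.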
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/ternary.h
@@ -80,6 +80,11 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
                       MetaTensor* output_box,
                       MetaConfig config = MetaConfig());

void ClipTensorInferMeta(const MetaTensor& x,
                         const MetaTensor& min,
                         const MetaTensor& max,
                         MetaTensor* out);

void CollectFpnProposalsInferMeta(
    const std::vector<const MetaTensor*>& multi_level_rois,
    const std::vector<const MetaTensor*>& multi_level_scores,
8 changes: 0 additions & 8 deletions paddle/phi/kernels/clip_grad_kernel.h
@@ -28,12 +28,4 @@ void ClipGradKernel(const Context& dev_ctx,
                    const Scalar& max,
                    DenseTensor* x_grad);

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& min,
                          const DenseTensor& max,
                          const DenseTensor& out_grad,
                          DenseTensor* x_grad);

} // namespace phi
7 changes: 0 additions & 7 deletions paddle/phi/kernels/clip_kernel.h
@@ -28,11 +28,4 @@ void ClipKernel(const Context& dev_ctx,
                const Scalar& max,
                DenseTensor* out);

template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& min,
                      const DenseTensor& max,
                      DenseTensor* out);

} // namespace phi
31 changes: 31 additions & 0 deletions paddle/phi/kernels/clip_tensor_grad_kernel.h
@@ -0,0 +1,31 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/expand_kernel.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& min,
                          const DenseTensor& max,
                          const DenseTensor& out_grad,
                          DenseTensor* x_grad);

} // namespace phi
30 changes: 30 additions & 0 deletions paddle/phi/kernels/clip_tensor_kernel.h
@@ -0,0 +1,30 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/expand_kernel.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& min,
                      const DenseTensor& max,
                      DenseTensor* out);

} // namespace phi
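Kernels declared this way are ordinary function templates, so other kernels can invoke them directly — the grad kernel later in this commit calls ExpandKernel and CastKernel in exactly that style. A hypothetical call site, assuming an already-initialized CPUContext and float tensors (a sketch, not code from this commit):

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/clip_tensor_kernel.h"

// Clamp x element-wise into [min, max], writing the result to out.
void RunClipTensor(const phi::CPUContext& dev_ctx,
                   const phi::DenseTensor& x,
                   const phi::DenseTensor& min,
                   const phi::DenseTensor& max,
                   phi::DenseTensor* out) {
  phi::ClipTensorKernel<float, phi::CPUContext>(dev_ctx, x, min, max, out);
}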
34 changes: 0 additions & 34 deletions paddle/phi/kernels/cpu/clip_grad_kernel.cc
@@ -18,31 +18,6 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& min,
                          const DenseTensor& max,
                          const DenseTensor& out_grad,
                          DenseTensor* x_grad) {
  const T* x_data = x.data<T>();
  const T* min_data = min.data<T>();
  const T* max_data = max.data<T>();
  auto numel = x.numel();
  auto* dout = out_grad.data<T>();

  auto* dx = dev_ctx.template Alloc<T>(x_grad);
  for (int i = 0; i < numel; i++) {
    dx[i] = (x_data[i] > min_data[i] && x_data[i] < max_data[i])
                ? dout[i]
                : static_cast<T>(0);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(clip_grad,
                   CPU,
                   ALL_LAYOUT,
@@ -51,12 +26,3 @@ PD_REGISTER_KERNEL(clip_grad,
                   double,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(clip_tensor_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::ClipTensorGradKernel,
                   float,
                   double,
                   int,
                   int64_t) {}
32 changes: 0 additions & 32 deletions paddle/phi/kernels/cpu/clip_kernel.cc
@@ -18,37 +18,5 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& min,
                      const DenseTensor& max,
                      DenseTensor* out) {
  const T* x_data = x.data<T>();
  const T* min_data = min.data<T>();
  const T* max_data = max.data<T>();
  auto x_numel = x.numel();

  T* out_data = dev_ctx.template Alloc<T>(out);

  for (int i = 0; i < x_numel; i++) {
    out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i];
    out_data[i] = out_data[i] > max_data[i] ? max_data[i] : out_data[i];
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(
    clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}

PD_REGISTER_KERNEL(clip_tensor,
                   CPU,
                   ALL_LAYOUT,
                   phi::ClipTensorKernel,
                   float,
                   double,
                   int,
                   int64_t) {}
80 changes: 80 additions & 0 deletions paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc
@@ -0,0 +1,80 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_tensor_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"

namespace phi {

template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& min,
                          const DenseTensor& max,
                          const DenseTensor& out_grad,
                          DenseTensor* x_grad) {
  // Broadcast x, min, and max up to the gradient shape before the
  // element-wise pass below.
  DenseTensor ex_min;
  DenseTensor ex_max;
  DenseTensor ex_x;
  std::vector<int> real_target_shape = common::vectorize<int>(x_grad->dims());
  if (x.dims() != x_grad->dims()) {
    phi::ExpandKernel<T, Context>(dev_ctx, x, real_target_shape, &ex_x);
  } else {
    ex_x = x;
  }
  if (min.dims() != x_grad->dims()) {
    phi::ExpandKernel<T, Context>(dev_ctx, min, real_target_shape, &ex_min);
  } else {
    ex_min = min;
  }
  if (max.dims() != x_grad->dims()) {
    phi::ExpandKernel<T, Context>(dev_ctx, max, real_target_shape, &ex_max);
  } else {
    ex_max = max;
  }
  phi::CastKernel<T, Context>(dev_ctx, ex_min, ex_x.dtype(), &ex_min);
  phi::CastKernel<T, Context>(dev_ctx, ex_max, ex_x.dtype(), &ex_max);

  const T* x_data = ex_x.data<T>();
  const T* min_data = ex_min.data<T>();
  const T* max_data = ex_max.data<T>();
  auto numel = ex_x.numel();
  auto* dout = out_grad.data<T>();

  // The gradient passes through only where x was not clipped.
  auto* dx = dev_ctx.template Alloc<T>(x_grad);
  for (int i = 0; i < numel; i++) {
    dx[i] = (x_data[i] > min_data[i] && x_data[i] < max_data[i])
                ? dout[i]
                : static_cast<T>(0);
  }
}

} // namespace phi

PD_REGISTER_KERNEL(clip_tensor_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::ClipTensorGradKernel,
                   float,
                   double,
                   int,
                   int64_t) {}
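The loop at the heart of the grad kernel zeroes the gradient wherever x fell outside the per-element band. A self-contained demonstration of that masking rule, independent of Paddle (illustrative values only):

#include <cstdio>

int main() {
  // The band is [0, 2] everywhere; two of the four elements fall
  // outside it and therefore receive zero gradient.
  const float x[4] = {-2.0f, 0.5f, 1.5f, 3.0f};
  const float min_v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  const float max_v[4] = {2.0f, 2.0f, 2.0f, 2.0f};
  const float dout[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  float dx[4];
  for (int i = 0; i < 4; ++i) {
    dx[i] = (x[i] > min_v[i] && x[i] < max_v[i]) ? dout[i] : 0.0f;
  }
  printf("%g %g %g %g\n", dx[0], dx[1], dx[2], dx[3]);  // prints: 0 1 1 0
  return 0;
}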