Commit
Delete ElementwiseKernel in BroadcastKernel (#42779)
AnnaTrainingG authored May 20, 2022
1 parent c5d3bc0 commit 0d878f1
Showing 8 changed files with 80 additions and 48 deletions.
22 changes: 6 additions & 16 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -585,26 +585,16 @@ void BroadcastKernel(const KPDevice &ctx,
                      Functor func) {
   std::vector<int> dims_size;
   dims_size.reserve(ins.size());
-  bool no_broadcast_flag = true;
   for (auto *in : ins) {
-    no_broadcast_flag &= ins[0]->dims() == in->dims();
     dims_size.emplace_back(in->dims().size());
   }
 
-  if (ins.size() > 0 && outs->size() > 0) {
-    no_broadcast_flag &= outs->at(0)->dims() == ins[0]->dims();
-  }
-
-  if (no_broadcast_flag) {
-    phi::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(ctx, ins, outs, func);
-  } else {
-    axis = axis == -1
-               ? *std::max_element(dims_size.begin(), dims_size.end()) -
-                     *std::min_element(dims_size.begin(), dims_size.end())
-               : axis;
-    BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
-        ctx, ins, outs, axis, func);
-  }
+  axis = axis == -1
+             ? *std::max_element(dims_size.begin(), dims_size.end()) -
+                   *std::min_element(dims_size.begin(), dims_size.end())
+             : axis;
+  BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
+      ctx, ins, outs, axis, func);
 }
 
 template <typename Functor, typename T, typename OutType = T>
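With the same-shape fast path gone, BroadcastKernel always routes through BroadcastKernelForDifferentVecSize; call sites that know all inputs and the output share one shape now invoke ElementwiseKernel themselves, as the remaining files in this commit do. A minimal sketch of that call-site pattern (illustrative only: MyFunctor, x, y, out, and dev_ctx are placeholders, not part of this diff; the phi::funcs signatures mirror the calls shown below):

    // Hypothetical binary call site after this commit.
    std::vector<const phi::DenseTensor*> ins = {&x, &y};
    std::vector<phi::DenseTensor*> outs = {out};
    if (x.dims() == y.dims() && x.dims() == out->dims()) {
      // Same shapes everywhere: take the elementwise path directly.
      phi::funcs::ElementwiseKernel<T, MyFunctor<T>, 1>(
          dev_ctx, ins, &outs, MyFunctor<T>());
    } else {
      // Shapes differ: BroadcastKernel resolves the axis and vector size.
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
          dev_ctx, ins, &outs, -1, MyFunctor<T>());
    }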
10 changes: 6 additions & 4 deletions paddle/phi/kernels/gpu/gelu_grad_kernel.cu
@@ -81,11 +81,13 @@ void GeluGradKernel(const Context& dev_ctx,
       }
     }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
+    using Functor = GeluWithApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
+    using Functor = GeluWithoutApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }

10 changes: 6 additions & 4 deletions paddle/phi/kernels/gpu/gelu_kernel.cu
@@ -71,11 +71,13 @@ void GeluKernel(const Context& dev_ctx,
       }
     }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
+    using Functor = GeluWithApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
+    using Functor = GeluWithoutApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }

21 changes: 7 additions & 14 deletions paddle/phi/kernels/gpu/reduce_grad.h
@@ -43,22 +43,19 @@ void ReduceGrad(const GPUContext& dev_ctx,
       }));
 }
 
-template <typename T,
-          typename Context,
-          template <typename, typename> class TransformOp>
+template <typename T, typename OutT, typename Context, typename Functor>
 void ReduceGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& out_grad,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
-                      DenseTensor* x_grad) {
+                      DenseTensor* x_grad,
+                      Functor functor) {
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
 
-  auto pt_out_dtype = x.dtype();
-
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
   std::vector<int> reduce_dims =
@@ -79,14 +76,10 @@ void ReduceGradKernel(const Context& dev_ctx,
 
   auto pt_d_out = new_d_out;
   auto pt_d_x = *d_x;
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
-
-  phi::ReduceGrad<T, TransformOp<T, MPType>>(
-      dev_ctx,
-      &pt_d_out,
-      &pt_d_x,
-      pt_out_dtype,
-      TransformOp<T, MPType>(reduce_num));
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, OutT>(
+      dev_ctx, inputs, &outputs, 0, functor);
 }
 
 }  // namespace phi
19 changes: 17 additions & 2 deletions paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -29,8 +29,23 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::DivideFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  int dim_size = x.dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (x.dims())[i];
+  }
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  ReduceGradKernel<T, T, Context, kps::DivideFunctor<T, MPType>>(
+      dev_ctx,
+      x,
+      out_grad,
+      dims,
+      keep_dim,
+      reduce_all,
+      x_grad,
+      kps::DivideFunctor<T, MPType>(reduce_num));
 }
 
 }  // namespace phi
37 changes: 34 additions & 3 deletions paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -29,8 +29,40 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::IdentityFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  auto out_dtype = x.dtype();
+  auto* in_x = &x;
+  auto* d_out = &out_grad;
+  auto* d_x = x_grad;
+
+  // get reduce_dim and reduce_num for reduce_mean_grad
+  int dim_size = in_x->dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+
+  auto update_dims = vectorize(d_x->dims());
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (in_x->dims())[i];
+    update_dims[i] = 1;
+  }
+  // make new tensor
+  DenseTensor new_d_out(d_out->dtype());
+  new_d_out.ShareDataWith(*d_out);
+  new_d_out.Resize(phi::make_ddim(update_dims));
+
+  dev_ctx.Alloc(d_x, x.dtype());
+  auto pt_out_dtype = x.dtype();
+  auto pt_d_out = new_d_out;
+  auto pt_d_x = *d_x;
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
+      dev_ctx,
+      &pt_d_out,
+      &pt_d_x,
+      pt_out_dtype,
+      kps::IdentityFunctor<T, MPType>());
 }
 
 }  // namespace phi
@@ -48,4 +80,3 @@ PD_REGISTER_KERNEL(sum_grad,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-
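For reference, the shape bookkeeping these reduce-grad kernels share — multiplying up the reduced extents into reduce_num and viewing the upstream gradient with those extents set to 1 so it can be broadcast back over the input shape — is small enough to check in isolation. A standalone sketch in plain C++ (example shapes assumed; not Paddle code):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Example: x has shape {4, 5, 6} and was reduced over axis 1.
      std::vector<int64_t> x_dims = {4, 5, 6};
      std::vector<int> reduce_dims = {1};

      std::vector<int64_t> update_dims = x_dims;
      int reduce_num = 1;
      for (int i : reduce_dims) {
        reduce_num *= x_dims[i];  // elements folded into each output value
        update_dims[i] = 1;       // gradient is viewed with this extent as 1
      }

      std::cout << "reduce_num = " << reduce_num << "\n";  // prints 5
      std::cout << "grad viewed as {" << update_dims[0] << ", " << update_dims[1]
                << ", " << update_dims[2] << "}\n";  // prints {4, 1, 6}
      return 0;
    }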

3 changes: 1 addition & 2 deletions paddle/phi/kernels/gpu/where_kernel.cu
@@ -40,8 +40,7 @@ void WhereKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
 
   CondFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
-      ctx, ins, &outs, -1, func);
+  funcs::ElementwiseKernel<T, CondFunctor<T>, 1>(ctx, ins, &outs, func);
 }
 
 }  // namespace phi
6 changes: 3 additions & 3 deletions paddle/phi/kernels/kps/bitwise_kernel.cu
@@ -51,9 +51,9 @@ void BitwiseNotKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
-  funcs::BitwiseNotFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
-      dev_ctx, ins, &outs, -1, func);
+  funcs::BitwiseNotFunctor<T> unary_func;
+  funcs::ElementwiseKernel<T, funcs::BitwiseNotFunctor<T>>(
+      dev_ctx, ins, &outs, unary_func);
 }
 
 }  // namespace phi
