Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/phi/kernels/funcs/reduce_grad_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ void ReduceGradFunctor(const Context& dev_ctx,
auto x_dims = input0.dims();
auto reduced_dims_v = common::vectorize(x_dims);
std::vector<int> dims_ref = dims;
Eigen::array<int, D> broadcast_dim;
Eigen::array<int64_t, D> broadcast_dim;
for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;

int broad_cast_times = 1;
int64_t broad_cast_times = 1;
for (size_t i = 0; i < dims_ref.size(); ++i) {
if (dims_ref[i] < 0) {
dims_ref[i] = x_rank + dims_ref[i];
Expand Down Expand Up @@ -142,7 +142,7 @@ void LaunchReduceGradKernel(const Context& dev_ctx,
auto& place = *dev_ctx.eigen_device();
// *dev_ctx.eigen_device();
auto broadcast_dim =
Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
Eigen::array<int64_t, 1>({{static_cast<int64_t>(input0->numel())}});
functor(place,
&x,
&x_reduce,
Expand Down
39 changes: 30 additions & 9 deletions paddle/phi/kernels/gpu/dist_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ struct PowFunctor {
Ty p_order_;
};

template <typename Tx,
typename Ty,
typename Tout> // Tx is high precision, Tout is low/out precision
struct PowFunctorHighPrecision {
HOSTDEVICE explicit inline PowFunctorHighPrecision(const Ty& p_order)
: p_order_(p_order) {}
HOSTDEVICE inline Tx operator()(const Tx x) const {
return static_cast<Tout>(pow(static_cast<Ty>(x), p_order_));
}
Ty p_order_;
};

template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
const T* x, const T* y, T* out, int64_t N, Functor func) {
Expand Down Expand Up @@ -126,16 +138,17 @@ void DistKernel(const Context& dev_ctx,
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();

T* o_ptr = dev_ctx.template Alloc<T>(out);
auto stream = dev_ctx.stream();

auto xdim = x.dims();
if (xdim == y.dims()) { // same shape
auto n = x.numel();
int64_t n = x.numel();

auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
intermediate.Resize(common::make_ddim({config.block_per_grid.x}));
T* i_ptr = dev_ctx.template Alloc<T>(&intermediate);

std::vector<int64_t> axis_dims = {static_cast<int64_t>(-1)};
std::vector<int> reduce_axis =
funcs::details::GetReduceDim(axis_dims, xdim.size(), true);
Expand Down Expand Up @@ -166,15 +179,23 @@ void DistKernel(const Context& dev_ctx,
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);

const DenseTensor* tmp_norm = out;
std::vector<const DenseTensor*> ins = {tmp_norm};
DenseTensor out_other;
out_other.Resize(out->dims());
dev_ctx.template Alloc<MT>(&out_other);

phi::funcs::
ReduceKernel<T, MT, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx,
intermediate,
&out_other,
kps::IdentityFunctor<MT>(),
reduce_axis);
std::vector<const DenseTensor*> ins = {&out_other};
std::vector<DenseTensor*> outs = {out};
MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);

MT p_order_ = static_cast<MT>(1.f / p_order);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
dev_ctx, ins, &outs, PowFunctorHighPrecision<MT, MT, T>(p_order_));
}

} else {
Expand Down
35 changes: 29 additions & 6 deletions paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ struct AbsMaxAndMinGradFunctor {

template <typename T>
struct PNormGradFunctor {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
this->porder = static_cast<T>(porder - 1.);
this->eps = static_cast<T>(eps);
this->porder = static_cast<MT>(porder - 1.);
this->eps = static_cast<MT>(eps);
}

template <typename Context,
typename X,
typename Y,
Expand All @@ -59,12 +61,33 @@ struct PNormGradFunctor {
DY* dy,
const Dim& dim,
int size) {
auto x_mt = x->template cast<MT>();
auto y_mt = y->template cast<MT>();
auto dy_mt = dy->template cast<MT>();

auto norm_pow = y_mt.pow(-this->porder);
auto mask_norm_nonzero = (y_mt != static_cast<MT>(0)).template cast<MT>();

// Set to 0 where porder < 0 and x == 0
MT zero = static_cast<MT>(0);
auto mask_x_zero = (x_mt == zero).template cast<MT>();

MT is_porder_negative =
this->porder < zero ? static_cast<MT>(1) : static_cast<MT>(0);
auto invalid_mask = (mask_x_zero * is_porder_negative);
auto safe_pow =
x_mt.abs().pow(this->porder) * (static_cast<MT>(1) - invalid_mask);

dx->device(place) =
(*x).abs().pow(this->porder) * (*x).sign() * dy->broadcast(dim) *
(*y + y->constant(eps)).pow(-this->porder).broadcast(dim);
(safe_pow * x_mt.sign() * dy_mt.broadcast(dim) *
norm_pow.broadcast(dim) *
mask_norm_nonzero.broadcast(dim) // Mask out positions where norm == 0
)
.template cast<T>();
}
T porder;
T eps;

MT porder;
MT eps;
};

template <typename T, typename Context>
Expand Down
27 changes: 17 additions & 10 deletions paddle/phi/kernels/gpu/p_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,31 +124,38 @@ void PNormKernel(const Context& dev_ctx,
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
#else
DenseTensor out_temp;
out_temp.Resize(out_norm->dims());
dev_ctx.template Alloc<MT>(&out_temp);

if (porder == 1.0) {
// fast 1-norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsFunctor<T>>(
dev_ctx, *in_x, out_norm, FabsFunctor<T>(), reduce_axis);
} else if (porder == 2.0) {
// fast 2-norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, SquareFunctor<T>>(
dev_ctx, *in_x, out_norm, SquareFunctor<T>(), reduce_axis);
phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, SquareFunctor<MT>>(
dev_ctx, *in_x, &out_temp, SquareFunctor<MT>(), reduce_axis);
} else if (porder == 3.0) {
// fast 3-norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsCubicFunctor<T>>(
dev_ctx, *in_x, out_norm, FabsCubicFunctor<T>(), reduce_axis);
phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, FabsCubicFunctor<MT>>(
dev_ctx, *in_x, &out_temp, FabsCubicFunctor<MT>(), reduce_axis);
} else {
// vanilla norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis);
phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, UnsignedPowFunctor<MT>>(
dev_ctx,
*in_x,
&out_temp,
UnsignedPowFunctor<MT>(porder),
reduce_axis);
}

if (porder != 1.0) {
// save computation when porder is 1.0
const DenseTensor* tmp_norm = out_norm;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<const DenseTensor*> ins = {&out_temp};
std::vector<DenseTensor*> outs = {out_norm};
MT p_order_ = static_cast<MT>(1.f / porder);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
dev_ctx, ins, &outs, UnsignedPowFunctor<MT>(p_order_));
}
#endif
}
Expand Down