Skip to content

Commit

Permalink
polish code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangting2020 committed Jun 20, 2022
1 parent a1f15cb commit a5f4a39
Showing 1 changed file with 69 additions and 43 deletions.
112 changes: 69 additions & 43 deletions paddle/phi/kernels/gpudnn/softmax_gpudnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,31 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
#endif
}

template <typename T>
void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
const int axis,
const bool log_mode,
DenseTensor* out) {
auto* out_data = out->data<T>();
auto* x_data = x.data<T>();
const int rank = x.dims().size();

std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
int64_t remaining = tensor_dims[0];
int dim = tensor_dims[1];
int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxForwardCudnnKernel<T>(
dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data);
x_data += offset;
out_data += offset;
remaining -= batch_size;
}
}

template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
const T* out_data,
Expand Down Expand Up @@ -868,30 +893,60 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
#endif
}

template <typename T>
void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const int axis,
const bool log_mode,
DenseTensor* dx) {
auto* dx_data = dx->data<T>();
auto* out_data = out.data<T>();
auto* dout_data = dout.data<T>();
int rank = out.dims().size();

std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
int64_t remaining = tensor_dims[0];
int dim = tensor_dims[1];
int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxBackwardCudnnKernel<T>(dev_ctx,
out_data,
dout_data,
axis,
rank,
log_mode,
tensor_dims,
dx_data);
out_data += offset;
dout_data += offset;
dx_data += offset;
remaining -= batch_size;
}
}

#if CUDNN_VERSION < 8100
template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const T* x_data,
const DenseTensor& x,
const int axis,
const int rank,
const bool log_mode,
const std::vector<int>& tensor_dims,
T* out_data) {
DenseTensor* out) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
}
template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const T* out_data,
const T* dout_data,
const DenseTensor& out,
const DenseTensor& dout,
const int axis,
const int rank,
const bool log_mode,
const std::vector<int>& tensor_dims,
T* dx_data) {
DenseTensor* dx) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
Expand Down Expand Up @@ -982,18 +1037,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
dim_log2);
}
} else {
int64_t remaining = N;
auto* x_data = x.data<T>();
int64_t batch_size = INT_MAX / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxForwardCudnnKernel<T>(
dev_ctx, x_data, axis, rank, LogMode, tensor_dims, out_data);
x_data += offset;
out_data += offset;
remaining -= batch_size;
}
LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
}
} else {
LaunchNormalSoftmaxForward<T, LogMode>(
Expand Down Expand Up @@ -1068,26 +1112,8 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
dim_log2);
}
} else {
int64_t remaining = N;
auto* out_data = out.data<T>();
auto* dout_data = dout.data<T>();
int64_t batch_size = INT_MAX / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxBackwardCudnnKernel<T>(dev_ctx,
out_data,
dout_data,
axis,
rank,
LogMode,
tensor_dims,
dx_data);
out_data += offset;
dout_data += offset;
dx_data += offset;
remaining -= batch_size;
}
LaunchSoftmaxBackwardCudnnKernel<T>(
dev_ctx, out, dout, axis, LogMode, dx);
}
} else {
LaunchNormalSoftmaxBackward<T, LogMode>(
Expand Down

0 comments on commit a5f4a39

Please sign in to comment.