slice large tensor for cudnn_softmax
zhangting2020 committed Jun 20, 2022
1 parent b6bc6f7 commit a1f15cb
Showing 1 changed file with 159 additions and 127 deletions.
paddle/phi/kernels/gpudnn/softmax_gpudnn.h: 286 changes (159 additions, 127 deletions)
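Why the tensor is sliced: cuDNN describes tensors with int dimensions, so once the flattened input holds more than INT_MAX elements, a single descriptor can no longer cover it. The INT_MAX / dim slice size introduced in this patch points at that limit (this reading is an inference from the patch, not a statement by its author). The standalone sketch below works through the arithmetic with illustrative sizes.

#include <climits>
#include <cstdint>
#include <iostream>

int main() {
  const int dim = 1024;                      // softmax axis length
  const int64_t N = 4LL * 1024 * 1024;       // rows in the flattened [N, dim] view
  const int64_t elements = N * dim;          // 2^32 elements, more than INT_MAX
  const int64_t batch_size = INT_MAX / dim;  // rows one cuDNN call can cover
  const int64_t slices = (N + batch_size - 1) / batch_size;
  std::cout << elements << " elements -> " << slices << " slices of at most "
            << batch_size << " rows each\n";
  return 0;
}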
@@ -772,15 +772,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
 
 template <typename T>
 void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
-                               const DenseTensor& x,
+                               const T* x_data,
                                const int axis,
+                               const int rank,
                                const bool log_mode,
-                               DenseTensor* out) {
-  auto* out_data = out->data<T>();
-
-  const int rank = x.dims().size();
-  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
-
+                               const std::vector<int>& tensor_dims,
+                               T* out_data) {
   auto handle = dev_ctx.cudnn_handle();
   GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
 
@@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
         handle,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        x.data<T>(),
+        x_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         out_data,
@@ -812,7 +809,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
         mode,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        x.data<T>(),
+        x_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         out_data));
@@ -821,16 +818,13 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
 
 template <typename T>
 void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
-                                const DenseTensor& out,
-                                const DenseTensor& dout,
+                                const T* out_data,
+                                const T* dout_data,
                                 const int axis,
+                                const int rank,
                                 const bool log_mode,
-                                DenseTensor* dx) {
-  auto* dx_data = dx->data<T>();
-
-  int rank = out.dims().size();
-  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
-
+                                const std::vector<int>& tensor_dims,
+                                T* dx_data) {
   auto handle = dev_ctx.cudnn_handle();
   GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
 
@@ -846,9 +840,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
         handle,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        out.data<T>(),
+        out_data,
         desc,
-        dout.data<T>(),
+        dout_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         dx_data,
@@ -865,9 +859,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
         mode,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        out.data<T>(),
+        out_data,
         desc,
-        dout.data<T>(),
+        dout_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         dx_data));
@@ -878,22 +872,26 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
 
 template <>
 inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
-    const DenseTensor& x,
+    const phi::dtype::bfloat16* x_data,
     const int axis,
+    const int rank,
     const bool log_mode,
-    DenseTensor* out) {
+    const std::vector<int>& tensor_dims,
+    phi::dtype::bfloat16* out_data) {
   PADDLE_THROW(errors::Unavailable(
       "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
       "8100."));
 }
 template <>
 inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
-    const DenseTensor& out,
-    const DenseTensor& dout,
+    const phi::dtype::bfloat16* out_data,
+    const phi::dtype::bfloat16* dout_data,
     const int axis,
+    const int rank,
     const bool log_mode,
-    DenseTensor* dx) {
+    const std::vector<int>& tensor_dims,
+    phi::dtype::bfloat16* dx_data) {
   PADDLE_THROW(errors::Unavailable(
       "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
       "8100."));
Expand Down Expand Up @@ -933,60 +931,73 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1];
int D = tensor_dims[2];

if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;

// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);

// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;

if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;

// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);

// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;

if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
}
} else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
int64_t remaining = N;
auto* x_data = x.data<T>();
int64_t batch_size = INT_MAX / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxForwardCudnnKernel<T>(
dev_ctx, x_data, axis, rank, LogMode, tensor_dims, out_data);
x_data += offset;
out_data += offset;
remaining -= batch_size;
}
}
} else if (D > 1) {
} else {
LaunchNormalSoftmaxForward<T, LogMode>(
dev_ctx, out_data, x.data<T>(), N, dim, D);
} else {
SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
}
}
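The heart of the change is the loop just added above: when the cuDNN path is taken, rows are processed in chunks of at most INT_MAX / dim, so each call sees an element count that fits in an int, and the data pointers advance by a fixed element offset per chunk. Below is a self-contained host-side sketch of that pattern; the stub, names, and sizes are illustrative, not Paddle code.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for the per-slice cuDNN launch; rows * dim always fits in an int.
void SoftmaxOnSlice(const float* x, float* out, int rows, int dim) {
  std::cout << "slice of " << rows << " rows\n";
  std::copy(x, x + static_cast<int64_t>(rows) * dim, out);  // placeholder work
}

void SoftmaxSliced(const float* x, float* out, int64_t n_rows, int dim) {
  const int64_t batch_size = INT_MAX / dim;  // rows per slice, as in the patch
  const int64_t offset = batch_size * dim;   // elements per full slice
  int64_t remaining = n_rows;
  while (remaining > 0) {
    const int rows = static_cast<int>(std::min(remaining, batch_size));
    SoftmaxOnSlice(x, out, rows, dim);
    x += offset;    // the final slice may be short; the loop condition,
    out += offset;  // not the pointer math, decides when to stop
    remaining -= batch_size;
  }
}

int main() {
  const int dim = 8;
  const int64_t n_rows = 16;
  std::vector<float> x(n_rows * dim, 1.0f), out(n_rows * dim);
  SoftmaxSliced(x.data(), out.data(), n_rows, dim);  // a single slice at this size
}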

@@ -1005,61 +1016,82 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
   int dim = tensor_dims[1];
   int D = tensor_dims[2];
 
-  if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
-    int dim_log2 = Log2Ceil(dim);
-    int dim_ceil = 1 << dim_log2;
-    int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
-    int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;
-
-    constexpr int threads_per_block = 128;
-
-    int warps_per_block = (threads_per_block / warp_size);
-    int batches_per_block = warps_per_block * batches_per_warp;
-    int blocks = (N + batches_per_block - 1) / batches_per_block;
-    dim3 threads(warp_size, warps_per_block, 1);
-
-    // vectorization read/write
-    using T4 = typename VecT4<T>::Type;
-    using T2 = typename VecT2<T>::Type;
-    if (dim % 4 == 0) {
-      SwitchWarpSoftmaxBackward<T, T4, LogMode>(
-          blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, dim, dim_log2);
-    } else if (dim % 2 == 0) {
-      SwitchWarpSoftmaxBackward<T, T2, LogMode>(
-          blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, dim, dim_log2);
-    } else {
-      SwitchWarpSoftmaxBackward<T, T, LogMode>(
-          blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, dim, dim_log2);
-    }
+  if (D == 1) {
+    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
+      int dim_log2 = Log2Ceil(dim);
+      int dim_ceil = 1 << dim_log2;
+      int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
+      int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;
+
+      constexpr int threads_per_block = 128;
+
+      int warps_per_block = (threads_per_block / warp_size);
+      int batches_per_block = warps_per_block * batches_per_warp;
+      int blocks = (N + batches_per_block - 1) / batches_per_block;
+      dim3 threads(warp_size, warps_per_block, 1);
+
+      // vectorization read/write
+      using T4 = typename VecT4<T>::Type;
+      using T2 = typename VecT2<T>::Type;
+      if (dim % 4 == 0) {
+        SwitchWarpSoftmaxBackward<T, T4, LogMode>(
+            blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, dim, dim_log2);
+      } else if (dim % 2 == 0) {
+        SwitchWarpSoftmaxBackward<T, T2, LogMode>(
+            blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, dim, dim_log2);
+      } else {
+        SwitchWarpSoftmaxBackward<T, T, LogMode>(
+            blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, dim, dim_log2);
+      }
+    } else {
+      int64_t remaining = N;
+      auto* out_data = out.data<T>();
+      auto* dout_data = dout.data<T>();
+      int64_t batch_size = INT_MAX / dim;
+      int offset = batch_size * dim;
+      while (remaining > 0) {
+        tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
+        SoftmaxBackwardCudnnKernel<T>(
+            dev_ctx, out_data, dout_data, axis, rank, LogMode, tensor_dims, dx_data);
+        out_data += offset;
+        dout_data += offset;
+        dx_data += offset;
+        remaining -= batch_size;
+      }
+    }
-  } else if (D > 1) {
+  } else {
     LaunchNormalSoftmaxBackward<T, LogMode>(
         dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
-  } else {
-    SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
   }
 }
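The backward hunk mirrors the forward one: the same INT_MAX / dim slice size is reused, and out_data, dout_data, and dx_data all advance by the same element offset so the three views stay aligned slice by slice. Note that remaining -= batch_size can go negative after a final partial slice; that is harmless, since the while condition only tests remaining > 0 and the last slice's row count was already clamped via std::min.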

