slice large tensor for cudnn_softmax #43681

Merged · 2 commits · Jun 21, 2022

Changes from 1 commit
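Context for the change: cuDNN tensor descriptors take int (32-bit) dimensions, so a single softmax call cannot describe a batch whose element count exceeds INT32_MAX. This PR slices an oversized batch and calls cuDNN once per slice. A standalone sketch of the idea, with illustrative names (the callback stands in for Paddle's descriptor-setup-plus-launch code; this is not the PR's actual API):

#include <algorithm>
#include <cstdint>
#include <limits>

// Apply a cuDNN-style softmax to n rows of length dim, slicing the batch so
// that each call sees at most INT32_MAX / dim rows (rows * dim fits in int).
// `cudnn_softmax` is a hypothetical stand-in for the real cuDNN call.
template <typename T, typename CudnnSoftmaxFn>
void SoftmaxBySlices(const T* x, T* out, int64_t n, int64_t dim,
                     CudnnSoftmaxFn cudnn_softmax) {
  const int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
  const int64_t offset = batch_size * dim;  // elements consumed per slice
  for (int64_t remaining = n; remaining > 0; remaining -= batch_size) {
    const int rows = static_cast<int>(std::min(remaining, batch_size));
    cudnn_softmax(x, out, rows, static_cast<int>(dim));
    x += offset;
    out += offset;
  }
}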
286 changes: 159 additions & 127 deletions paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -772,15 +772,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,

template <typename T>
void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
const T* x_data,
const int axis,
const int rank,
const bool log_mode,
DenseTensor* out) {
auto* out_data = out->data<T>();

const int rank = x.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);

const std::vector<int>& tensor_dims,
T* out_data) {
auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;

@@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
handle,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
x.data<T>(),
x_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
out_data,
@@ -812,7 +809,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
x.data<T>(),
x_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
out_data));
@@ -821,16 +818,13 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,

template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const T* out_data,
const T* dout_data,
const int axis,
const int rank,
const bool log_mode,
DenseTensor* dx) {
auto* dx_data = dx->data<T>();

int rank = out.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);

const std::vector<int>& tensor_dims,
T* dx_data) {
auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;

@@ -846,9 +840,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
handle,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
out.data<T>(),
out_data,
desc,
dout.data<T>(),
dout_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
dx_data,
@@ -865,9 +859,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
out.data<T>(),
out_data,
desc,
dout.data<T>(),
dout_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
dx_data));
@@ -878,22 +872,26 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& x,
const T* x_data,
const int axis,
const int rank,
const bool log_mode,
DenseTensor* out) {
const std::vector<int>& tensor_dims,
T* out_data) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
}
template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const T* out_data,
const T* dout_data,
const int axis,
const int rank,
const bool log_mode,
DenseTensor* dx) {
const std::vector<int>& tensor_dims,
T* dx_data) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
@@ -933,60 +931,73 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1];
int D = tensor_dims[2];

if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;

// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);

// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;

if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;

// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);

// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;

if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
}
} else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
int64_t remaining = N;
auto* x_data = x.data<T>();
int64_t batch_size = INT_MAX / dim;
Contributor:

INT_MAX -> std::numeric_limits<int32_t>::max()

Contributor Author:

done
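For reference, the suggested spelling swaps the C macro for the type-safe C++ constant; a minimal sketch of the chunk-size computation after the fix (the helper name MaxRowsPerCall is hypothetical):

#include <cstdint>
#include <limits>

// Largest number of rows whose total element count rows * dim still fits in
// a 32-bit int, as required by cuDNN's int-typed descriptor dimensions.
int64_t MaxRowsPerCall(int64_t dim) {
  return std::numeric_limits<int32_t>::max() / dim;
}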

int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxForwardCudnnKernel<T>(
dev_ctx, x_data, axis, rank, LogMode, tensor_dims, out_data);
x_data += offset;
out_data += offset;
remaining -= batch_size;
}
}
} else if (D > 1) {
} else {
LaunchNormalSoftmaxForward<T, LogMode>(
dev_ctx, out_data, x.data<T>(), N, dim, D);
} else {
SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
}
}
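After this hunk the forward driver has three paths: the warp-level kernel when D == 1 and the UseCudnnSoftmax heuristic declines cuDNN, the sliced cuDNN call otherwise, and LaunchNormalSoftmaxForward when D > 1. The warp path's vector-width choice above reduces to a divisibility check; a minimal sketch (VecT4/VecT2 are the diff's own traits; the helper name is illustrative):

// Pick the widest per-thread vector load that tiles `dim` exactly: 4-wide
// (VecT4<T>::Type), 2-wide (VecT2<T>::Type), or scalar T. Wider loads mean
// fewer memory transactions per row in the warp softmax kernels.
int VectorWidth(int dim) {
  if (dim % 4 == 0) return 4;
  if (dim % 2 == 0) return 2;
  return 1;
}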

@@ -1005,61 +1016,82 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1];
int D = tensor_dims[2];

if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = Log2Ceil(dim);
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;

constexpr int threads_per_block = 128;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);

// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = Log2Ceil(dim);
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;

constexpr int threads_per_block = 128;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);

// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
}
} else {
SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
int64_t remaining = N;
Contributor:

If this is only needed to support the cuDNN implementation, wouldn't it be better to put this logic directly inside the SoftmaxForward/BackwardCudnnKernel functions, to keep this function from growing too long?

Contributor Author:

This block was removed from the outermost softmax interface and the slicing logic was encapsulated.
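A sketch of what that encapsulation could look like for the backward pass, assembled from the signatures in this diff (the wrapper name LaunchSoftmaxBackwardCudnnKernel is hypothetical; the actual follow-up commit may differ):

// Hypothetical wrapper that hides the slicing loop from the driver, in the
// spirit of the reviewer's suggestion above.
template <typename T>
void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                      const DenseTensor& out,
                                      const DenseTensor& dout,
                                      const int axis,
                                      const bool log_mode,
                                      DenseTensor* dx) {
  auto* out_data = out.data<T>();
  auto* dout_data = dout.data<T>();
  auto* dx_data = dx->data<T>();
  const int rank = out.dims().size();
  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);

  // Slice the batch so each cuDNN call stays within 32-bit dimensions.
  const int64_t batch_size =
      std::numeric_limits<int32_t>::max() / tensor_dims[1];
  const int64_t offset = batch_size * tensor_dims[1];
  for (int64_t remaining = tensor_dims[0]; remaining > 0;
       remaining -= batch_size) {
    tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
    SoftmaxBackwardCudnnKernel<T>(dev_ctx, out_data, dout_data, axis, rank,
                                  log_mode, tensor_dims, dx_data);
    out_data += offset;
    dout_data += offset;
    dx_data += offset;
  }
}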

auto* out_data = out.data<T>();
auto* dout_data = dout.data<T>();
int64_t batch_size = INT_MAX / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxBackwardCudnnKernel<T>(dev_ctx,
out_data,
dout_data,
axis,
rank,
LogMode,
tensor_dims,
dx_data);
out_data += offset;
dout_data += offset;
dx_data += offset;
remaining -= batch_size;
}
}
} else if (D > 1) {
} else {
LaunchNormalSoftmaxBackward<T, LogMode>(
dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
} else {
SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
}
}
