diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index ca3574de77170..ef3406fd7f668 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -772,15 +772,12 @@ static std::vector GetSoftmaxTensorDims(const phi::DDim& dims, template void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, - const DenseTensor& x, + const T* x_data, const int axis, + const int rank, const bool log_mode, - DenseTensor* out) { - auto* out_data = out->data(); - - const int rank = x.dims().size(); - std::vector tensor_dims = GetSoftmaxTensorDims(x.dims(), axis); - + const std::vector& tensor_dims, + T* out_data) { auto handle = dev_ctx.cudnn_handle(); GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; @@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, handle, paddle::platform::CudnnDataType::kOne(), desc, - x.data(), + x_data, paddle::platform::CudnnDataType::kZero(), desc, out_data, @@ -812,25 +809,47 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, mode, paddle::platform::CudnnDataType::kOne(), desc, - x.data(), + x_data, paddle::platform::CudnnDataType::kZero(), desc, out_data)); #endif } +template +void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, + const DenseTensor& x, + const int axis, + const bool log_mode, + DenseTensor* out) { + auto* out_data = out->data(); + auto* x_data = x.data(); + const int rank = x.dims().size(); + + std::vector tensor_dims = GetSoftmaxTensorDims(x.dims(), axis); + int64_t remaining = tensor_dims[0]; + int dim = tensor_dims[1]; + int64_t batch_size = std::numeric_limits::max() / dim; + int offset = batch_size * dim; + while (remaining > 0) { + tensor_dims[0] = std::min(remaining, batch_size); + SoftmaxForwardCudnnKernel( + dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data); + x_data += offset; + out_data += offset; + remaining -= batch_size; + } +} + template void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, - const DenseTensor& out, - const DenseTensor& dout, + const T* out_data, + const T* dout_data, const int axis, + const int rank, const bool log_mode, - DenseTensor* dx) { - auto* dx_data = dx->data(); - - int rank = out.dims().size(); - std::vector tensor_dims = GetSoftmaxTensorDims(out.dims(), axis); - + const std::vector& tensor_dims, + T* dx_data) { auto handle = dev_ctx.cudnn_handle(); GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; @@ -846,9 +865,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, handle, paddle::platform::CudnnDataType::kOne(), desc, - out.data(), + out_data, desc, - dout.data(), + dout_data, paddle::platform::CudnnDataType::kZero(), desc, dx_data, @@ -865,18 +884,52 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, mode, paddle::platform::CudnnDataType::kOne(), desc, - out.data(), + out_data, desc, - dout.data(), + dout_data, paddle::platform::CudnnDataType::kZero(), desc, dx_data)); #endif } +template +void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, + const DenseTensor& out, + const DenseTensor& dout, + const int axis, + const bool log_mode, + DenseTensor* dx) { + auto* dx_data = dx->data(); + auto* out_data = out.data(); + auto* dout_data = dout.data(); + int rank = out.dims().size(); + + std::vector tensor_dims = GetSoftmaxTensorDims(out.dims(), axis); + int64_t remaining = tensor_dims[0]; + int dim = tensor_dims[1]; + int64_t batch_size = std::numeric_limits::max() / dim; + int offset = batch_size * dim; + while (remaining > 0) { + tensor_dims[0] = std::min(remaining, batch_size); + SoftmaxBackwardCudnnKernel(dev_ctx, + out_data, + dout_data, + axis, + rank, + log_mode, + tensor_dims, + dx_data); + out_data += offset; + dout_data += offset; + dx_data += offset; + remaining -= batch_size; + } +} + #if CUDNN_VERSION < 8100 template <> -inline void SoftmaxForwardCudnnKernel( +inline void LaunchSoftmaxForwardCudnnKernel( const GPUContext& dev_ctx, const DenseTensor& x, const int axis, @@ -887,7 +940,7 @@ inline void SoftmaxForwardCudnnKernel( "8100.")); } template <> -inline void SoftmaxBackwardCudnnKernel( +inline void LaunchSoftmaxBackwardCudnnKernel( const GPUContext& dev_ctx, const DenseTensor& out, const DenseTensor& dout, @@ -933,60 +986,62 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, int dim = tensor_dims[1]; int D = tensor_dims[2]; - if (D == 1 && !UseCudnnSoftmax(dev_ctx, dim, true)) { - int dim_log2 = static_cast(Log2Ceil(dim)); - int dim_ceil = 1 << dim_log2; - int warp_size = (dim_ceil < 32) ? dim_ceil : 32; - int batches_per_warp = (dim_ceil <= 32) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (N + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // vectorization read/write - using T4 = typename VecT4::Type; - using T2 = typename VecT2::Type; - - if (dim % 4 == 0) { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); - } else if (dim % 2 == 0) { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); + if (D == 1) { + if (!UseCudnnSoftmax(dev_ctx, dim, true)) { + int dim_log2 = static_cast(Log2Ceil(dim)); + int dim_ceil = 1 << dim_log2; + int warp_size = (dim_ceil < 32) ? dim_ceil : 32; + int batches_per_warp = (dim_ceil <= 32) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + + // vectorization read/write + using T4 = typename VecT4::Type; + using T2 = typename VecT2::Type; + + if (dim % 4 == 0) { + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); + } else if (dim % 2 == 0) { + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); + } else { + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); + } } else { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); + LaunchSoftmaxForwardCudnnKernel(dev_ctx, x, axis, LogMode, out); } - } else if (D > 1) { + } else { LaunchNormalSoftmaxForward( dev_ctx, out_data, x.data(), N, dim, D); - } else { - SoftmaxForwardCudnnKernel(dev_ctx, x, axis, LogMode, out); } } @@ -1005,61 +1060,64 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, int dim = tensor_dims[1]; int D = tensor_dims[2]; - if (D == 1 && !UseCudnnSoftmax(dev_ctx, dim, true)) { - int dim_log2 = Log2Ceil(dim); - int dim_ceil = 1 << dim_log2; - int warp_size = (dim_ceil < 32) ? dim_ceil : 32; - int batches_per_warp = (dim_ceil <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (N + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // vectorization read/write - using T4 = typename VecT4::Type; - using T2 = typename VecT2::Type; - if (dim % 4 == 0) { - SwitchWarpSoftmaxBackward(blocks, - threads, - dev_ctx, - dx_data, - dout.data(), - out.data(), - N, - dim, - dim, - dim_log2); - } else if (dim % 2 == 0) { - SwitchWarpSoftmaxBackward(blocks, - threads, - dev_ctx, - dx_data, - dout.data(), - out.data(), - N, - dim, - dim, - dim_log2); + if (D == 1) { + if (!UseCudnnSoftmax(dev_ctx, dim, true)) { + int dim_log2 = Log2Ceil(dim); + int dim_ceil = 1 << dim_log2; + int warp_size = (dim_ceil < 32) ? dim_ceil : 32; + int batches_per_warp = (dim_ceil <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + + // vectorization read/write + using T4 = typename VecT4::Type; + using T2 = typename VecT2::Type; + if (dim % 4 == 0) { + SwitchWarpSoftmaxBackward(blocks, + threads, + dev_ctx, + dx_data, + dout.data(), + out.data(), + N, + dim, + dim, + dim_log2); + } else if (dim % 2 == 0) { + SwitchWarpSoftmaxBackward(blocks, + threads, + dev_ctx, + dx_data, + dout.data(), + out.data(), + N, + dim, + dim, + dim_log2); + } else { + SwitchWarpSoftmaxBackward(blocks, + threads, + dev_ctx, + dx_data, + dout.data(), + out.data(), + N, + dim, + dim, + dim_log2); + } } else { - SwitchWarpSoftmaxBackward(blocks, - threads, - dev_ctx, - dx_data, - dout.data(), - out.data(), - N, - dim, - dim, - dim_log2); + LaunchSoftmaxBackwardCudnnKernel( + dev_ctx, out, dout, axis, LogMode, dx); } - } else if (D > 1) { + } else { LaunchNormalSoftmaxBackward( dev_ctx, dx_data, dout.data(), out.data(), N, dim, D); - } else { - SoftmaxBackwardCudnnKernel(dev_ctx, out, dout, axis, LogMode, dx); } }