diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
index a93151fe8e653..ef3406fd7f668 100644
--- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -816,6 +816,31 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
 #endif
 }
 
+template <typename T>
+void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
+                                     const DenseTensor& x,
+                                     const int axis,
+                                     const bool log_mode,
+                                     DenseTensor* out) {
+  auto* out_data = out->data<T>();
+  auto* x_data = x.data<T>();
+  const int rank = x.dims().size();
+
+  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
+  int64_t remaining = tensor_dims[0];
+  int dim = tensor_dims[1];
+  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
+  int offset = batch_size * dim;
+  while (remaining > 0) {
+    tensor_dims[0] = std::min(remaining, batch_size);
+    SoftmaxForwardCudnnKernel<T>(
+        dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data);
+    x_data += offset;
+    out_data += offset;
+    remaining -= batch_size;
+  }
+}
+
 template <typename T>
 void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                 const T* out_data,
@@ -868,30 +893,60 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
 #endif
 }
 
+template <typename T>
+void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
+                                      const DenseTensor& out,
+                                      const DenseTensor& dout,
+                                      const int axis,
+                                      const bool log_mode,
+                                      DenseTensor* dx) {
+  auto* dx_data = dx->data<T>();
+  auto* out_data = out.data<T>();
+  auto* dout_data = dout.data<T>();
+  int rank = out.dims().size();
+
+  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
+  int64_t remaining = tensor_dims[0];
+  int dim = tensor_dims[1];
+  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
+  int offset = batch_size * dim;
+  while (remaining > 0) {
+    tensor_dims[0] = std::min(remaining, batch_size);
+    SoftmaxBackwardCudnnKernel<T>(dev_ctx,
+                                  out_data,
+                                  dout_data,
+                                  axis,
+                                  rank,
+                                  log_mode,
+                                  tensor_dims,
+                                  dx_data);
+    out_data += offset;
+    dout_data += offset;
+    dx_data += offset;
+    remaining -= batch_size;
+  }
+}
+
 #if CUDNN_VERSION < 8100
 template <>
-inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
+inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
-    const T* x_data,
+    const DenseTensor& x,
     const int axis,
-    const int rank,
     const bool log_mode,
-    const std::vector<int>& tensor_dims,
-    T* out_data) {
+    DenseTensor* out) {
   PADDLE_THROW(errors::Unavailable(
       "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
       "8100."));
 }
 template <>
-inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
+inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
-    const T* out_data,
-    const T* dout_data,
+    const DenseTensor& out,
+    const DenseTensor& dout,
     const int axis,
-    const int rank,
     const bool log_mode,
-    const std::vector<int>& tensor_dims,
-    T* dx_data) {
+    DenseTensor* dx) {
   PADDLE_THROW(errors::Unavailable(
       "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
       "8100."));
@@ -982,18 +1037,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                              dim_log2);
       }
     } else {
-      int64_t remaining = N;
-      auto* x_data = x.data<T>();
-      int64_t batch_size = INT_MAX / dim;
-      int offset = batch_size * dim;
-      while (remaining > 0) {
-        tensor_dims[0] = std::min(remaining, batch_size);
-        SoftmaxForwardCudnnKernel<T>(
-            dev_ctx, x_data, axis, rank, LogMode, tensor_dims, out_data);
-        x_data += offset;
-        out_data += offset;
-        remaining -= batch_size;
-      }
+      LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
     }
   } else {
     LaunchNormalSoftmaxForward<T, LogMode>(
@@ -1068,26 +1112,8 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                              dim_log2);
       }
     } else {
-      int64_t remaining = N;
-      auto* out_data = out.data<T>();
-      auto* dout_data = dout.data<T>();
-      int64_t batch_size = INT_MAX / dim;
-      int offset = batch_size * dim;
-      while (remaining > 0) {
-        tensor_dims[0] = std::min(remaining, batch_size);
-        SoftmaxBackwardCudnnKernel<T>(dev_ctx,
-                                      out_data,
-                                      dout_data,
-                                      axis,
-                                      rank,
-                                      LogMode,
-                                      tensor_dims,
-                                      dx_data);
-        out_data += offset;
-        dout_data += offset;
-        dx_data += offset;
-        remaining -= batch_size;
-      }
+      LaunchSoftmaxBackwardCudnnKernel<T>(
+          dev_ctx, out, dout, axis, LogMode, dx);
     }
   } else {
     LaunchNormalSoftmaxBackward<T, LogMode>(