diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index bdddf05fcdd54..fbaa1f193a608 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -967,6 +967,25 @@ inline void LaunchSoftmaxBackwardCudnnKernel( } #endif +template +bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) { + bool cudnn_available = ctx.cudnn_handle(); + if (!ctx.cudnn_handle()) { + if (std::is_same::value) { +#if CUDNN_VERSION < 8100 + cudnn_available = false; +#endif + } + } + constexpr int max_dim = 512; + if (!cudnn_available || !last_dim || + (softmax_dim <= max_dim && sizeof(T) <= 4)) { + return false; + } else { + return true; + } +} + template void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, const DenseTensor& x,