Skip to content

Commit

Permalink
add UseCudnnSoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangting2020 committed Jun 21, 2022
1 parent ce6666a commit a089601
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions paddle/phi/kernels/gpudnn/softmax_gpudnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,25 @@ inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
}
#endif

template <typename T>
bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) {
bool cudnn_available = ctx.cudnn_handle();
if (!ctx.cudnn_handle()) {
if (std::is_same<T, phi::dtype::bfloat16>::value) {
#if CUDNN_VERSION < 8100
cudnn_available = false;
#endif
}
}
constexpr int max_dim = 512;
if (!cudnn_available || !last_dim ||
(softmax_dim <= max_dim && sizeof(T) <= 4)) {
return false;
} else {
return true;
}
}

template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
const DenseTensor& x,
Expand Down

0 comments on commit a089601

Please sign in to comment.