Skip to content

Commit

Permalink
Add launch_bounds (PaddlePaddle#47285)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wong4j authored and zlsh80826 committed Nov 23, 2022
1 parent fe1b84b commit 812ea0d
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions paddle/fluid/operators/fused/fused_dropout_act_bias.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,19 @@ template <typename T,
int BlockSizeX,
int BlockSizeY,
int VecSize,
typename Functor>
__global__ void FusedDropoutActBiasGrad(Functor act_grad,
const T *dout,
const MaskType *mask,
const T *src,
const T *bias,
const T factor,
const int64_t rows,
const int64_t cols,
T *dx,
T *dbias) {
typename Functor,
int THREADS_PER_CTA = BlockSizeX *BlockSizeY>
__global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
Functor act_grad,
const T *dout,
const MaskType *mask,
const T *src,
const T *bias,
const T factor,
const int64_t rows,
const int64_t cols,
T *dx,
T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;

using LoadT = phi::AlignedVector<T, VecSize>;
Expand Down

0 comments on commit 812ea0d

Please sign in to comment.