Skip to content

Commit

Permalink
add launch bound to limit the registers usage for volta architecture (#…
Browse files Browse the repository at this point in the history
…38113)

From --ptxas-options=-v, SegmentOpsKernel uses 66 registers in a block.
There are two ways to resolve this problem:
    Reduce the threads per block launch configuration
    add __launch_bound__ to give information to nvcc compiler for reducing registers usage
this PR chooses __launch_bound__ solution because changing gpu_launch_config may affect other ops.
  • Loading branch information
zlsh80826 authored Dec 17, 2021
1 parent 76eb371 commit 18a5982
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/operators/math/segment_pooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
}

template <typename T, typename Index, typename Helper, typename Pool>
__global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
T* output, Helper h, Pool pool) {
__global__ void __launch_bounds__(1024, 1)
SegmentOpsKernel(const Index* segment_ids, const T* input, T* output,
Helper h, Pool pool) {
CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
Index segment_offset, dim_index_base, actual_height;
Index inner_dim_size = h.inner_dim_size;
Expand Down

0 comments on commit 18a5982

Please sign in to comment.