Skip to content

Commit

Permalink
add launch bound to limit the registers usage for volta architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
zlsh80826 committed Dec 14, 2021
1 parent 4c1e27c commit 552ccb7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/operators/math/segment_pooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
}

template <typename T, typename Index, typename Helper, typename Pool>
__global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
T* output, Helper h, Pool pool) {
__global__ void __launch_bounds__(1024, 1)
SegmentOpsKernel(const Index* segment_ids, const T* input, T* output,
Helper h, Pool pool) {
CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
Index segment_offset, dim_index_base, actual_height;
Index inner_dim_size = h.inner_dim_size;
Expand Down

1 comment on commit 552ccb7

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.