[NVIDIA] Fix test_segment_ops unit test failed on V100 #38113

Merged
lanxianghit merged 1 commit into PaddlePaddle:develop from zlsh80826:fix_segment_ops_on_v100 on Dec 17, 2021

zlsh80826
Collaborator

PR types

Bug fixes

PR changes

OPs

Describe

According to the output of --ptxas-options=-v, SegmentOpsKernel uses 66 registers per thread.

ptxas info    : Function properties for _ZN6paddle9operators16SegmentOpsKernelIdlNS0_13ArrangeHelperIlEENS0_7MinPoolIdEEEEvPKT0_PKT_PS9_T1_T2_
    56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 66 registers, 425 bytes cmem[0]
ptxas info    : Compiling entry function '_ZN6paddle9operators16SegmentOpsKernelIdlNS0_13ArrangeHelperIlEENS0_7MaxPoolIdEEEEvPKT0_PKT_PS9_T1_T2_' for 'sm_70'
ptxas info    : Function properties for _ZN6paddle9operators16SegmentOpsKernelIdlNS0_13ArrangeHelperIlEENS0_7MaxPoolIdEEEEvPKT0_PKT_PS9_T1_T2_
    56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 66 registers, 425 bytes cmem[0]
ptxas info    : Compiling entry function '_ZN6paddle9operators16SegmentOpsKernelIdlNS0_13ArrangeHelperIlEENS0_7SumPoolIdEEEEvPKT0_PKT_PS9_T1_T2_' for 'sm_70'
ptxas info    : Function properties for _ZN6paddle9operators16SegmentOpsKernelIdlNS0_13ArrangeHelperIlEENS0_7SumPoolIdEEEEvPKT0_PKT_PS9_T1_T2_
    56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 66 registers, 425 bytes cmem[0]

SegmentPoolFunctor then launches SegmentOpsKernel with 1024 threads per block, which fails with cudaErrorLaunchOutOfResources (error 701). V100 has only 65536 registers per SM, so no block can run on V100 with this configuration (66 * 1024 = 67584 > 65536).
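
The arithmetic above can be checked with a small host-side sketch. The helper name `fits_register_file` and the constant name are illustrative and do not come from Paddle; the check also ignores the hardware's register allocation granularity and every other per-SM limit, so treat it as a first-order estimate only.

```cpp
#include <cstdint>

// Registers per SM on V100 (sm_70), as stated in the PR description.
constexpr std::uint32_t kV100RegistersPerSM = 65536;

// Hypothetical helper: true if the registers needed by one block fit into
// the register file of a single SM.
// Simplification: register allocation granularity is ignored.
constexpr bool fits_register_file(
    std::uint32_t regs_per_thread, std::uint32_t threads_per_block,
    std::uint32_t regs_per_sm = kV100RegistersPerSM) {
  return regs_per_thread * threads_per_block <= regs_per_sm;
}

// 66 registers/thread * 1024 threads = 67584 > 65536: the launch cannot fit.
static_assert(!fits_register_file(66, 1024), "1024-thread block must not fit");
// The same kernel with 512 threads per block needs 33792 registers and fits.
static_assert(fits_register_file(66, 512), "512-thread block must fit");
```

The `static_assert` lines make the compiler verify both cases, so the snippet needs no GPU to run.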

There are two ways to resolve this problem:

  1. Reduce the threads per block in the launch configuration
  2. Add __launch_bounds__ to tell the nvcc compiler to limit register usage

I chose the __launch_bounds__ solution because changing gpu_launch_config may affect other ops.
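
To make the two options concrete, here is a rough host-side sketch of the numbers each one implies on a 65536-register SM. Both helper names are hypothetical, and the calculation assumes the compiler's register cap is simply the register file divided by the threads that must be resident, which ignores allocation granularity.

```cpp
#include <cstdint>

constexpr std::uint32_t kRegistersPerSM = 65536;     // V100 figure from the description
constexpr std::uint32_t kWarpSize = 32;              // threads per warp
constexpr std::uint32_t kMaxRegistersPerThread = 255;  // per-thread hardware limit

// Option 1 (hypothetical helper): largest warp-aligned block size that still
// fits when every thread uses `regs_per_thread` registers.
constexpr std::uint32_t max_threads_per_block(std::uint32_t regs_per_thread) {
  return (kRegistersPerSM / regs_per_thread) / kWarpSize * kWarpSize;
}

// Option 2 (hypothetical helper): approximate per-thread register budget
// implied by __launch_bounds__(max_threads, min_blocks).
constexpr std::uint32_t register_budget(std::uint32_t max_threads,
                                        std::uint32_t min_blocks) {
  const std::uint32_t budget = kRegistersPerSM / (max_threads * min_blocks);
  return budget < kMaxRegistersPerThread ? budget : kMaxRegistersPerThread;
}

// With 66 registers per thread, at most 992 threads per block would fit.
static_assert(max_threads_per_block(66) == 992, "option 1 estimate");
// __launch_bounds__(1024, 1) implies a budget of about 64 registers per thread.
static_assert(register_budget(1024, 1) == 64, "option 2 estimate");
```

Under these assumptions, option 2 asks the compiler to trim the kernel from 66 to at most 64 registers per thread, while option 1 would leave the kernel untouched and shrink the block instead.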

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -120,8 +120,9 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
 }
 
 template <typename T, typename Index, typename Helper, typename Pool>
-__global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
-                                 T* output, Helper h, Pool pool) {
+__global__ void __launch_bounds__(1024, 1)
Collaborator

Some other kernels in segment_pooling.cu, such as SegmentMeanKernel, share the same launch config.
Could other kernel functions with the same launch config run into the same problem?

Collaborator Author

"Could other kernel functions with the same launch config run into the same problem?"

There might be, but currently no other kernels hit this problem on V100. NVCC doesn't know the runtime launch config, so by default it doesn't limit register usage. For example, this kernel can be run with 128/256/512 threads per block; if NVCC limited register usage, it might reduce the performance of those configurations.
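
As a rough illustration of why the unrestricted kernel still works at smaller block sizes, the sketch below counts how many blocks fit on one SM by register usage alone. The helper name is hypothetical, and the estimate deliberately ignores allocation granularity, shared memory, and the per-SM thread and block limits.

```cpp
#include <cstdint>

// Hypothetical helper: number of blocks that fit on one SM when only the
// register file is considered (65536 registers on V100 by default).
constexpr std::uint32_t blocks_per_sm_by_registers(
    std::uint32_t regs_per_thread, std::uint32_t threads_per_block,
    std::uint32_t regs_per_sm = 65536) {
  return regs_per_sm / (regs_per_thread * threads_per_block);
}

// At 66 registers per thread, 128/256/512-thread blocks all fit at least once.
static_assert(blocks_per_sm_by_registers(66, 128) == 7, "128 threads");
static_assert(blocks_per_sm_by_registers(66, 256) == 3, "256 threads");
static_assert(blocks_per_sm_by_registers(66, 512) == 1, "512 threads");
// Only the 1024-thread configuration cannot place a single block.
static_assert(blocks_per_sm_by_registers(66, 1024) == 0, "1024 threads");
```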

BTW, in my experience 1024 threads per block performs worse than 128/256 threads per block. The CUDA Best Practices guide also says:

"Between 128 and 256 threads per block is a good initial range for experimentation with different block sizes."

However, benchmarking and verifying every op that uses this launch config could be a large effort.

Collaborator

@ZHUI ZHUI left a comment

LGTM

@lanxianghit lanxianghit merged commit 18a5982 into PaddlePaddle:develop Dec 17, 2021
@zlsh80826 zlsh80826 deleted the fix_segment_ops_on_v100 branch January 2, 2022 11:33