Add __launch_bounds__ to kernels based on Executor properties #853
__global__ void
__launch_bounds__(threadsPerBlock, blocksPerSm)
Remember that these have different meanings for CUDA and HIP: "minimum blocks per CU" vs "minimum warps per EU" (AMD has 4 EU per CU, and NVIDIA has 1). I think we should choose the latter as our argument since it's more general and convert in the CUDA case.
Right, it's not exactly the same semantics. I'm pretty sure CUDA has 4 EUs (warp schedulers) per SM for almost all architectures except CC 6.0, which has 2.
I think the difference is that with NVIDIA, warps can move around execution units in an SM while that block is being executed (it's the block pinned to the SM), whereas with AMD warps are pinned to a specific EU. It's another layer of granularity.
Is that so? I'll need to double-check, but I thought that a warp is bound to a scheduler in CUDA as well.
I don't like having to duplicate the kernel and constructors that much, but I haven't found another way other than using preprocessor directives. Also, we'd need to compile VecGeom with limited register usage as well; otherwise the compiled device code would use more registers than the launch bounds specified by the ActionLauncher allow.
@sethrj If we limit the number of registers for propagation in a uniform field, e.g. by specifying
…an Applier member type
@sethrj I think this is ready, save for default values for max_block_size, but I wouldn't use any. Either we don't specify a launch bound or we use optimal bounds after profiling.
This is great stuff: clever but neat and simple. I've got just a couple of comments.
@esseivaju it looks like the
{
"const_mem": 0,
"heap_size": 68702699520,
-"local_mem": 424,
-"max_blocks_per_cu": 4,
-"max_threads_per_block": 1024,
-"max_warps_per_eu": 4,
+"local_mem": 744,
+"max_blocks_per_cu": 8,
+"max_threads_per_block": 256,
+"max_warps_per_eu": 8,
"name": "along-step-uniform-msc-propagate",
-"num_regs": 128,
-"occupancy": 0.5,
+"num_regs": 64,
+"occupancy": 1.0,
"print_buffer_size": 0,
"threads_per_block": 256
},
The equivalent kernel in CUDA (V100) is
{
"const_mem": 0,
"heap_size": 8388608,
"local_mem": 0,
"max_blocks_per_cu": 1,
"max_threads_per_block": 256,
"max_warps_per_eu": 8,
"name": "along-step-uniform-msc-propagate",
"num_regs": 184,
"occupancy": 0.125,
"print_buffer_size": 5242880,
"stack_size": 1024,
"threads_per_block": 256
}
Maybe we should be using "blocks per CU" as the criterion, as you originally suggested, and scale it for HIP...
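To make the scaling between the two criteria concrete, here is a minimal host-side sketch that converts a "warps per EU" bound into a "blocks per CU" bound. The helper `min_blocks_per_cu` is hypothetical and not part of this PR. The hardware numbers are assumptions passed in explicitly: a wavefront size of 64 with 4 EUs per CU for the HIP device, and a warp size of 32 with the SM treated as a single EU for the V100. Under those assumptions the formula reproduces the `max_blocks_per_cu` values in the diagnostics above.

```cpp
#include <cassert>

// Hypothetical helper (not part of the PR): convert a "minimum warps per
// execution unit" bound into a "minimum blocks per compute unit" bound.
// All hardware parameters are explicit because they are assumptions.
constexpr unsigned int min_blocks_per_cu(unsigned int warps_per_eu,
                                         unsigned int threads_per_block,
                                         unsigned int warp_size,
                                         unsigned int eus_per_cu)
{
    // Threads resident on one CU, divided by the block size and rounded up
    // so that the requested number of warps per EU is still reached.
    return (warps_per_eu * eus_per_cu * warp_size + threads_per_block - 1)
           / threads_per_block;
}

// HIP diagnostics above (assumed: wavefront size 64, 4 EUs per CU)
static_assert(min_blocks_per_cu(4, 256, 64, 4) == 4, "HIP, before the change");
static_assert(min_blocks_per_cu(8, 256, 64, 4) == 8, "HIP, after the change");

// CUDA V100 diagnostics above (assumed: warp size 32, SM counted as one EU)
static_assert(min_blocks_per_cu(8, 256, 32, 1) == 1, "CUDA V100");
```

If the number of warp schedulers per SM were counted as separate EUs on NVIDIA, `eus_per_cu` would change and so would the result, which is the ambiguity discussed earlier in this thread.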
Add template parameters to ActionLauncher to specify __launch_bounds__ for each kernel launch. It probably makes sense to add a compile-time option to enable/disable launch_bounds, but we'd ideally enable it independently for each kernel since it might improve performance for only a subset. Not sure what would be the most ergonomic way; maybe a CMake variable with a list of action labels?
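One possible shape for a per-kernel opt-in combined with a global switch is sketched below in host-only C++17. Everything here is illustrative: `DEMO_USE_LAUNCH_BOUNDS`, `HasMaxBlockSize`, `launch_bound`, and the two executor structs are hypothetical names, not the ActionLauncher interface. The idea is that an executor opts in by declaring a static `max_block_size`, and a build-level macro can turn all bounds off.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical build switch; a CMake option could define it to 0 to disable
// every launch bound at once.
#ifndef DEMO_USE_LAUNCH_BOUNDS
#    define DEMO_USE_LAUNCH_BOUNDS 1
#endif

// Detect an optional static member `max_block_size` on an executor type.
template<class F, class = void>
struct HasMaxBlockSize : std::false_type
{
};

template<class F>
struct HasMaxBlockSize<F, std::void_t<decltype(F::max_block_size)>>
    : std::true_type
{
};

// Executor without a bound: its kernel would be compiled as usual.
struct PlainExecutor
{
    void operator()(int) const {}
};

// Executor that opts in, for example with a value found by profiling.
struct BoundedExecutor
{
    static constexpr int max_block_size = 256;
    void operator()(int) const {}
};

// Return the bound to apply, or 0 when none is requested or the switch is off.
template<class F>
constexpr int launch_bound()
{
    if constexpr (DEMO_USE_LAUNCH_BOUNDS && HasMaxBlockSize<F>::value)
    {
        return F::max_block_size;
    }
    else
    {
        return 0;
    }
}

static_assert(launch_bound<PlainExecutor>() == 0, "no bound requested");
```

In device code, the result of `launch_bound<F>()` could select between a kernel template declared with `__launch_bounds__(F::max_block_size)` and an unbounded one, so only executors that opt in pay the register limit.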