Target 8192 blocks instead of split to large grid for large reduction #35997
Conversation
When the number of blocks is large enough, we are already achieving balanced SM allocation, but we should still keep the number of inputs per thread large, because thread reduce is cheap. Benchmark for Half on V100: https://github.com/zasdfgbnm/things/blob/master/2020Q2/reduction-benchmark.ipynb On a large tensor, it is 1.37ms vs 1.25ms.
LGTM. Please update the comment in code with your commit description.
aten/src/ATen/native/cuda/Reduce.cuh (outdated)

@@ -789,15 +789,23 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
     config.output_mult[1] = config.split_output(block_height);
   }

   if (config.input_mult[1] != 0 && config.values_per_thread() >= 256 && num_outputs <= 4096) {
     constexpr int target_grid_size = 4096;
This generally looks good, but you can probably make it even less than 4096? You should be targeting full occupancy, which will come out to less than 4096.
…pytorch#35997)

Summary:
Pull Request resolved: pytorch#35997

When the number of blocks is large enough, we are already achieving balanced SM allocation, but we should still keep the number of inputs per thread large, because thread reduce is cheap.

Benchmark for Half on V100:
https://github.com/zasdfgbnm/things/blob/master/2020Q2/reduction-benchmark.ipynb

On a large tensor, it is 1.37ms vs 1.25ms.

Test Plan: Imported from OSS

Differential Revision: D20927533

Pulled By: ngimel

fbshipit-source-id: 40df52e439cc1c01cda66c6195b600f301c5e984
Stack from ghstack:
When the number of blocks is large enough, we are already achieving
balanced SM allocation, but we should still keep the number of inputs
per thread large, because thread reduce is cheap.
Benchmark for Half on V100:
https://github.com/zasdfgbnm/things/blob/master/2020Q2/reduction-benchmark.ipynb
On a large tensor, it is 1.37ms vs 1.25ms.
Differential Revision: D20927533