Skip to content

Commit 8ea4da2

Browse files
bamaxwfacebook-github-bot
authored andcommitted
Replacing custom CUDA block reductions with CUB in sample_farthest_points
Summary: Removing hardcoded block reduction operation from `sample_farthest_points.cu` code, and replace it with `cub::BlockReduce` reducing complexity of the code, and letting established libraries do the thinking for us. Reviewed By: bottler Differential Revision: D38617147 fbshipit-source-id: b230029c55f05cda0aab1648d3105a8d3e92d27b
1 parent 597bc7c commit 8ea4da2

File tree

1 file changed

+26
-50
lines changed

1 file changed

+26
-50
lines changed

Diff for: pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu

+26-50
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <math.h>
1313
#include <stdio.h>
1414
#include <stdlib.h>
15+
#include <cub/cub.cuh>
1516
#include "utils/warp_reduce.cuh"
1617

1718
template <unsigned int block_size>
@@ -25,20 +26,19 @@ __global__ void FarthestPointSamplingKernel(
2526
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
2627
// clang-format on
2728
) {
29+
typedef cub::BlockReduce<
30+
cub::KeyValuePair<int64_t, float>,
31+
block_size,
32+
cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
33+
BlockReduce;
34+
__shared__ typename BlockReduce::TempStorage temp_storage;
35+
__shared__ int64_t selected_store;
36+
2837
// Get constants
2938
const int64_t N = points.size(0);
3039
const int64_t P = points.size(1);
3140
const int64_t D = points.size(2);
3241

33-
// Create single shared memory buffer which is split and cast to different
34-
// types: dists/dists_idx are used to save the maximum distances seen by the
35-
// points processed by any one thread and the associated point indices.
36-
// These values only need to be accessed by other threads in this block which
37-
// are processing the same batch and not by other blocks.
38-
extern __shared__ char shared_buf[];
39-
float* dists = (float*)shared_buf; // block_size floats
40-
int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t
41-
4242
// Get batch index and thread index
4343
const int64_t batch_idx = blockIdx.x;
4444
const size_t tid = threadIdx.x;
@@ -82,43 +82,26 @@ __global__ void FarthestPointSamplingKernel(
8282
max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
8383
}
8484

85-
// After going through all points for this thread, save the max
86-
// point and idx seen by this thread. Each thread sees P/block_size points.
87-
dists[tid] = max_dist;
88-
dists_idx[tid] = max_dist_idx;
89-
// Sync to ensure all threads in the block have updated their max point.
90-
__syncthreads();
91-
92-
// Parallelized block reduction to find the max point seen by
93-
// all the threads in this block for iteration k.
94-
// Each block represents one batch element so we can use a divide/conquer
95-
// approach to find the max, syncing all threads after each step.
96-
97-
for (int s = block_size / 2; s > 0; s >>= 1) {
98-
if (tid < s) {
99-
// Compare the best point seen by two threads and update the shared
100-
// memory at the location of the first thread index with the max out
101-
// of the two threads.
102-
if (dists[tid] < dists[tid + s]) {
103-
dists[tid] = dists[tid + s];
104-
dists_idx[tid] = dists_idx[tid + s];
105-
}
106-
}
107-
__syncthreads();
108-
}
109-
110-
// TODO(nikhilar): As reduction proceeds, the number of “active” threads
111-
// decreases. When tid < 32, there should only be one warp left which could
112-
// be unrolled.
113-
114-
// The overall max after reducing will be saved
115-
// at the location of tid = 0.
116-
selected = dists_idx[0];
85+
// max_dist, max_dist_idx are now the max point and idx seen by this thread.
86+
// Now find the index corresponding to the maximum distance seen by any
87+
// thread. (This value is only on thread 0.)
88+
selected =
89+
BlockReduce(temp_storage)
90+
.Reduce(
91+
cub::KeyValuePair<int64_t, float>(max_dist_idx, max_dist),
92+
cub::ArgMax(),
93+
block_size)
94+
.key;
11795

11896
if (tid == 0) {
11997
// Write the farthest point for iteration k to global memory
12098
idxs[batch_idx][k] = selected;
99+
selected_store = selected;
121100
}
101+
102+
// Ensure `selected` in all threads equals the global maximum.
103+
__syncthreads();
104+
selected = selected_store;
122105
}
123106
}
124107

@@ -185,15 +168,8 @@ at::Tensor FarthestPointSamplingCuda(
185168
auto min_point_dist_a =
186169
min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();
187170

188-
// Initialize the shared memory which will be used to store the
189-
// distance/index of the best point seen by each thread.
190-
size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t);
191-
// TODO: using shared memory for min_point_dist gives an ~2x speed up
192-
// compared to using a global (N, P) shaped tensor, however for
193-
// larger pointclouds this may exceed the shared memory limit per block.
194-
// If a speed up is required for smaller pointclouds, then the storage
195-
// could be switched to shared memory if the required total shared memory is
196-
// within the memory limit per block.
171+
// TempStorage for the reduction uses static shared memory only.
172+
size_t shared_mem = 0;
197173

198174
// Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
199175
// block.

0 commit comments

Comments
 (0)