12
12
#include < math.h>
13
13
#include < stdio.h>
14
14
#include < stdlib.h>
15
+ #include < cub/cub.cuh>
15
16
#include " utils/warp_reduce.cuh"
16
17
17
18
template <unsigned int block_size>
@@ -25,20 +26,19 @@ __global__ void FarthestPointSamplingKernel(
25
26
const at::PackedTensorAccessor64<int64_t , 1 , at::RestrictPtrTraits> start_idxs
26
27
// clang-format on
27
28
) {
29
+ typedef cub::BlockReduce<
30
+ cub::KeyValuePair<int64_t , float >,
31
+ block_size,
32
+ cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
33
+ BlockReduce;
34
+ __shared__ typename BlockReduce::TempStorage temp_storage;
35
+ __shared__ int64_t selected_store;
36
+
28
37
// Get constants
29
38
const int64_t N = points.size (0 );
30
39
const int64_t P = points.size (1 );
31
40
const int64_t D = points.size (2 );
32
41
33
- // Create single shared memory buffer which is split and cast to different
34
- // types: dists/dists_idx are used to save the maximum distances seen by the
35
- // points processed by any one thread and the associated point indices.
36
- // These values only need to be accessed by other threads in this block which
37
- // are processing the same batch and not by other blocks.
38
- extern __shared__ char shared_buf[];
39
- float * dists = (float *)shared_buf; // block_size floats
40
- int64_t * dists_idx = (int64_t *)&dists[block_size]; // block_size int64_t
41
-
42
42
// Get batch index and thread index
43
43
const int64_t batch_idx = blockIdx .x ;
44
44
const size_t tid = threadIdx .x ;
@@ -82,43 +82,26 @@ __global__ void FarthestPointSamplingKernel(
82
82
max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
83
83
}
84
84
85
- // After going through all points for this thread, save the max
86
- // point and idx seen by this thread. Each thread sees P/block_size points.
87
- dists[tid] = max_dist;
88
- dists_idx[tid] = max_dist_idx;
89
- // Sync to ensure all threads in the block have updated their max point.
90
- __syncthreads ();
91
-
92
- // Parallelized block reduction to find the max point seen by
93
- // all the threads in this block for iteration k.
94
- // Each block represents one batch element so we can use a divide/conquer
95
- // approach to find the max, syncing all threads after each step.
96
-
97
- for (int s = block_size / 2 ; s > 0 ; s >>= 1 ) {
98
- if (tid < s) {
99
- // Compare the best point seen by two threads and update the shared
100
- // memory at the location of the first thread index with the max out
101
- // of the two threads.
102
- if (dists[tid] < dists[tid + s]) {
103
- dists[tid] = dists[tid + s];
104
- dists_idx[tid] = dists_idx[tid + s];
105
- }
106
- }
107
- __syncthreads ();
108
- }
109
-
110
- // TODO(nikhilar): As reduction proceeds, the number of “active” threads
111
- // decreases. When tid < 32, there should only be one warp left which could
112
- // be unrolled.
113
-
114
- // The overall max after reducing will be saved
115
- // at the location of tid = 0.
116
- selected = dists_idx[0 ];
85
+ // max_dist, max_dist_idx are now the max point and idx seen by this thread.
86
+ // Now find the index corresponding to the maximum distance seen by any
87
+ // thread. (This value is only on thread 0.)
88
+ selected =
89
+ BlockReduce (temp_storage)
90
+ .Reduce (
91
+ cub::KeyValuePair<int64_t , float >(max_dist_idx, max_dist),
92
+ cub::ArgMax (),
93
+ block_size)
94
+ .key ;
117
95
118
96
if (tid == 0 ) {
119
97
// Write the farthest point for iteration k to global memory
120
98
idxs[batch_idx][k] = selected;
99
+ selected_store = selected;
121
100
}
101
+
102
+ // Ensure `selected` in all threads equals the global maximum.
103
+ __syncthreads ();
104
+ selected = selected_store;
122
105
}
123
106
}
124
107
@@ -185,15 +168,8 @@ at::Tensor FarthestPointSamplingCuda(
185
168
auto min_point_dist_a =
186
169
min_point_dist.packed_accessor64 <float , 2 , at::RestrictPtrTraits>();
187
170
188
- // Initialize the shared memory which will be used to store the
189
- // distance/index of the best point seen by each thread.
190
- size_t shared_mem = threads * sizeof (float ) + threads * sizeof (int64_t );
191
- // TODO: using shared memory for min_point_dist gives an ~2x speed up
192
- // compared to using a global (N, P) shaped tensor, however for
193
- // larger pointclouds this may exceed the shared memory limit per block.
194
- // If a speed up is required for smaller pointclouds, then the storage
195
- // could be switched to shared memory if the required total shared memory is
196
- // within the memory limit per block.
171
+ // TempStorage for the reduction uses static shared memory only.
172
+ size_t shared_mem = 0 ;
197
173
198
174
// Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
199
175
// block.
0 commit comments