@@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
     const size_t P1,
     const size_t P2,
     const size_t D,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
   // Store both dists and indices for knn in global memory.
   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
   const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
         scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
         scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
         scalar_t diff = coord1 - coord2;
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
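Note on the forward change: every kernel now accumulates either the squared difference (norm == 2) or the absolute difference (norm == 1) per coordinate before mink.add(dist, p2) pushes the candidate into the running top-K. A minimal standalone sketch of that term (a hypothetical helper, not part of this diff; abs resolves to CUDA's floating-point device overload):

// Hypothetical helper mirroring the per-coordinate term used by all four
// forward kernels: L2 sums squared differences, L1 sums absolute differences.
// The host entry point guarantees norm is 1 or 2 before any launch.
template <typename scalar_t>
__device__ inline scalar_t NormTerm(scalar_t diff, size_t norm) {
  return (norm == 2) ? diff * diff : abs(diff);
}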
@@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
     const size_t N,
     const size_t P1,
     const size_t P2,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
   // Same idea as the previous version, but hoist D into a template argument
   // so we can cache the current point in a thread-local array. We still store
   // the current best K dists and indices in global memory, so this should work
@@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
       const size_t N,
       const size_t P1,
       const size_t P2,
-      const size_t K) {
+      const size_t K,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
   }
 };
 
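The functor exists so the runtime value of D can be lifted into a template parameter; in the repo a dispatch helper generates the runtime-to-compile-time switch. A simplified, hypothetical stand-in for that mechanism (the run signature is inferred from the call site above, so treat it as an assumption):

// Hypothetical hand-written dispatch; the real code generates this switch.
// Making D a compile-time constant lets the kernel keep cur_point[D] in a
// fixed-size register array instead of local memory.
switch (D) {
  case 2:
    KNearestNeighborV1Functor<scalar_t, 2>::run(
        blocks, threads, points1, points2, lengths1, lengths2,
        dists, idxs, N, P1, P2, K, norm);
    break;
  case 3:
    KNearestNeighborV1Functor<scalar_t, 3>::run(
        blocks, threads, points1, points2, lengths1, lengths2,
        dists, idxs, N, P1, P2, K, norm);
    break;
  // ... remaining supported values of D ...
}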
@@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
     int64_t* __restrict__ idxs,
     const int64_t N,
     const int64_t P1,
-    const int64_t P2) {
+    const int64_t P2,
+    const size_t norm) {
   // Same general implementation as V1, but also hoist K into a template arg.
   scalar_t cur_point[D];
   scalar_t min_dists[K];
@@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
         scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
       int64_t* __restrict__ idxs,
       const int64_t N,
       const int64_t P1,
-      const int64_t P2) {
+      const int64_t P2,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
   }
 };
 
@@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
     int64_t* __restrict__ idxs,
     const size_t N,
     const size_t P1,
-    const size_t P2) {
+    const size_t P2,
+    const size_t norm) {
   // Same idea as V2, but use register indexing for thread-local arrays.
   // Enabling sorting for this version leads to huge slowdowns; I suspect
   // that it forces min_dists into local memory rather than registers.
@@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
         scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
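On the register-indexing comment above: a thread-local array can stay in registers only when every index into it is a compile-time constant, so V3 replaces dynamic accesses to min_dists with fully unrolled loops. A hedged sketch of the idea (the repo keeps its own helpers for this; the function below is illustrative only):

// Sketch of register indexing: after full unrolling, every arr[k] access has
// a constant index, so the compiler can keep arr in registers; a plain
// runtime arr[i] access would likely spill the array to local memory.
template <typename T, int N>
__device__ inline T RegRead(const T (&arr)[N], int i) {
  T out = arr[0];
#pragma unroll
  for (int k = 1; k < N; ++k) {
    if (k == i) {
      out = arr[k];
    }
  }
  return out;
}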
@@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
       int64_t* __restrict__ idxs,
       const size_t N,
       const size_t P1,
-      const size_t P2) {
+      const size_t P2,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
   }
 };
 
@@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K,
+    const int norm,
+    const int K,
     int version) {
   // Check inputs are on the same device
   at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
   const auto D = p2.size(2);
   const int64_t K_64 = K;
 
+  TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
+
   TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
   auto long_dtype = lengths1.options().dtype(at::kLong);
   auto idxs = at::zeros({N, P1, K}, long_dtype);
@@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           P1,
           P2,
           D,
-          K);
+          K,
+          norm);
     }));
   } else if (version == 1) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           N,
           P1,
           P2,
-          K);
+          K,
+          norm);
     }));
   } else if (version == 2) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           idxs.data_ptr<int64_t>(),
           N,
           P1,
-          P2);
+          P2,
+          norm);
     }));
   } else if (version == 3) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           idxs.data_ptr<int64_t>(),
           N,
           P1,
-          P2);
+          P2,
+          norm);
     }));
   }
   AT_CUDA_CHECK(cudaGetLastError());
@@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
     const size_t P1,
     const size_t P2,
     const size_t K,
-    const size_t D) {
+    const size_t D,
+    const size_t norm) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = gridDim.x * blockDim.x;
 
@@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
       if (p2_idx == -1) {
         continue;
       }
-      const float diff = 2.0 * grad_dist *
-          (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      float diff = 0.0;
+      if (norm == 1) {
+        float sign =
+            (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
+            ? 1.0
+            : -1.0;
+        diff = grad_dist * sign;
+      } else { // norm is 2
+        diff = 2.0 * grad_dist *
+            (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      }
       atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
       atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
     }
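The branch above implements the per-coordinate (sub)gradient of each norm. Writing the chain rule for a single coordinate, with upstream gradient grad_dist:

\frac{\partial}{\partial p_1}(p_1 - p_2)^2 = 2\,(p_1 - p_2),
\qquad
\frac{\partial}{\partial p_1}\lvert p_1 - p_2 \rvert = \operatorname{sign}(p_1 - p_2)

so the L2 path scales grad_dist by 2(p1 - p2) while the L1 path multiplies it by ±1; the code's > comparison picks the subgradient -1 at the tie p1 == p2. The two atomicAdd calls then scatter diff and -diff, since several p1 points may share the same neighbor p2.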
@@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    int norm,
     const at::Tensor& grad_dists) {
   // Check inputs are on the same device
   at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
       P1,
       P2,
       K,
-      D);
+      D,
+      norm);
 
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_p1, grad_p2);