Skip to content

Commit 67fff95

Browse files
gkioxarifacebook-github-bot
authored andcommitted
add L1 support for KNN & Chamfer
Summary: Added L1 norm for KNN and chamfer op * The norm is now specified with a variable `norm` which can only be 1 or 2 Reviewed By: bottler Differential Revision: D35419637 fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
1 parent 4b94649 commit 67fff95

File tree

8 files changed

+266
-130
lines changed

8 files changed

+266
-130
lines changed

pytorch3d/csrc/knn/knn.cu

+53-23
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
3636
const size_t P1,
3737
const size_t P2,
3838
const size_t D,
39-
const size_t K) {
39+
const size_t K,
40+
const size_t norm) {
4041
// Store both dists and indices for knn in global memory.
4142
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
4243
const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
5657
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
5758
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
5859
scalar_t diff = coord1 - coord2;
59-
dist += diff * diff;
60+
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
61+
dist += norm_diff;
6062
}
6163
mink.add(dist, p2);
6264
}
@@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
7476
const size_t N,
7577
const size_t P1,
7678
const size_t P2,
77-
const size_t K) {
79+
const size_t K,
80+
const size_t norm) {
7881
// Same idea as the previous version, but hoist D into a template argument
7982
// so we can cache the current point in a thread-local array. We still store
8083
// the current best K dists and indices in global memory, so this should work
@@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
99102
scalar_t dist = 0;
100103
for (int d = 0; d < D; ++d) {
101104
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
102-
dist += diff * diff;
105+
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
106+
dist += norm_diff;
103107
}
104108
mink.add(dist, p2);
105109
}
@@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
121125
const size_t N,
122126
const size_t P1,
123127
const size_t P2,
124-
const size_t K) {
128+
const size_t K,
129+
const size_t norm) {
125130
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
126131
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
127-
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
132+
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
128133
}
129134
};
130135

@@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
138143
int64_t* __restrict__ idxs,
139144
const int64_t N,
140145
const int64_t P1,
141-
const int64_t P2) {
146+
const int64_t P2,
147+
const size_t norm) {
142148
// Same general implementation as V2, but also hoist K into a template arg.
143149
scalar_t cur_point[D];
144150
scalar_t min_dists[K];
@@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
161167
for (int d = 0; d < D; ++d) {
162168
int offset = n * P2 * D + p2 * D + d;
163169
scalar_t diff = cur_point[d] - points2[offset];
164-
dist += diff * diff;
170+
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
171+
dist += norm_diff;
165172
}
166173
mink.add(dist, p2);
167174
}
@@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
186193
int64_t* __restrict__ idxs,
187194
const int64_t N,
188195
const int64_t P1,
189-
const int64_t P2) {
196+
const int64_t P2,
197+
const size_t norm) {
190198
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
191199
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
192-
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
200+
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
193201
}
194202
};
195203

@@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
203211
int64_t* __restrict__ idxs,
204212
const size_t N,
205213
const size_t P1,
206-
const size_t P2) {
214+
const size_t P2,
215+
const size_t norm) {
207216
// Same idea as V2, but use register indexing for thread-local arrays.
208217
// Enabling sorting for this version leads to huge slowdowns; I suspect
209218
// that it forces min_dists into local memory rather than registers.
@@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
229238
for (int d = 0; d < D; ++d) {
230239
int offset = n * P2 * D + p2 * D + d;
231240
scalar_t diff = cur_point[d] - points2[offset];
232-
dist += diff * diff;
241+
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
242+
dist += norm_diff;
233243
}
234244
mink.add(dist, p2);
235245
}
@@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
254264
int64_t* __restrict__ idxs,
255265
const size_t N,
256266
const size_t P1,
257-
const size_t P2) {
267+
const size_t P2,
268+
const size_t norm) {
258269
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
259270
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
260-
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
271+
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
261272
}
262273
};
263274

@@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
305316
const at::Tensor& p2,
306317
const at::Tensor& lengths1,
307318
const at::Tensor& lengths2,
308-
int K,
319+
const int norm,
320+
const int K,
309321
int version) {
310322
// Check inputs are on the same device
311323
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
324336
const auto D = p2.size(2);
325337
const int64_t K_64 = K;
326338

339+
TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
340+
327341
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
328342
auto long_dtype = lengths1.options().dtype(at::kLong);
329343
auto idxs = at::zeros({N, P1, K}, long_dtype);
@@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
366380
P1,
367381
P2,
368382
D,
369-
K);
383+
K,
384+
norm);
370385
}));
371386
} else if (version == 1) {
372387
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
387402
N,
388403
P1,
389404
P2,
390-
K);
405+
K,
406+
norm);
391407
}));
392408
} else if (version == 2) {
393409
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
410426
idxs.data_ptr<int64_t>(),
411427
N,
412428
P1,
413-
P2);
429+
P2,
430+
norm);
414431
}));
415432
} else if (version == 3) {
416433
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
433450
idxs.data_ptr<int64_t>(),
434451
N,
435452
P1,
436-
P2);
453+
P2,
454+
norm);
437455
}));
438456
}
439457
AT_CUDA_CHECK(cudaGetLastError());
@@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
459477
const size_t P1,
460478
const size_t P2,
461479
const size_t K,
462-
const size_t D) {
480+
const size_t D,
481+
const size_t norm) {
463482
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
464483
const size_t stride = gridDim.x * blockDim.x;
465484

@@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
481500
if (p2_idx == -1) {
482501
continue;
483502
}
484-
const float diff = 2.0 * grad_dist *
485-
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
503+
float diff = 0.0;
504+
if (norm == 1) {
505+
float sign =
506+
(p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
507+
? 1.0
508+
: -1.0;
509+
diff = grad_dist * sign;
510+
} else { // norm is 2
511+
diff = 2.0 * grad_dist *
512+
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
513+
}
486514
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
487515
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
488516
}
@@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
495523
const at::Tensor& lengths1,
496524
const at::Tensor& lengths2,
497525
const at::Tensor& idxs,
526+
int norm,
498527
const at::Tensor& grad_dists) {
499528
// Check inputs are on the same device
500529
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
547576
P1,
548577
P2,
549578
K,
550-
D);
579+
D,
580+
norm);
551581

552582
AT_CUDA_CHECK(cudaGetLastError());
553583
return std::make_tuple(grad_p1, grad_p2);

pytorch3d/csrc/knn/knn.h

+18-9
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
// containing P2 points of dimension D.
2222
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
2323
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
24+
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
2425
// K: int giving the number of nearest points to return.
2526
// version: Integer telling which implementation to use.
2627
//
@@ -41,35 +42,39 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
4142
const at::Tensor& p2,
4243
const at::Tensor& lengths1,
4344
const at::Tensor& lengths2,
44-
int K);
45+
const int norm,
46+
const int K);
4547

4648
// CUDA implementation
4749
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
4850
const at::Tensor& p1,
4951
const at::Tensor& p2,
5052
const at::Tensor& lengths1,
5153
const at::Tensor& lengths2,
52-
int K,
53-
int version);
54+
const int norm,
55+
const int K,
56+
const int version);
5457

5558
// Implementation which is exposed.
5659
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
5760
const at::Tensor& p1,
5861
const at::Tensor& p2,
5962
const at::Tensor& lengths1,
6063
const at::Tensor& lengths2,
61-
int K,
62-
int version) {
64+
const int norm,
65+
const int K,
66+
const int version) {
6367
if (p1.is_cuda() || p2.is_cuda()) {
6468
#ifdef WITH_CUDA
6569
CHECK_CUDA(p1);
6670
CHECK_CUDA(p2);
67-
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
71+
return KNearestNeighborIdxCuda(
72+
p1, p2, lengths1, lengths2, norm, K, version);
6873
#else
6974
AT_ERROR("Not compiled with GPU support.");
7075
#endif
7176
}
72-
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
77+
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
7378
}
7479

7580
// Compute gradients with respect to p1 and p2
@@ -86,6 +91,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
8691
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
8792
// It is padded with zeros so that it can be used easily in a later
8893
// gather() operation. This is computed from the forward pass.
94+
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
8995
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
9096
// gradients.
9197
//
@@ -102,6 +108,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
102108
const at::Tensor& lengths1,
103109
const at::Tensor& lengths2,
104110
const at::Tensor& idxs,
111+
const int norm,
105112
const at::Tensor& grad_dists);
106113

107114
// CUDA implementation
@@ -111,6 +118,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
111118
const at::Tensor& lengths1,
112119
const at::Tensor& lengths2,
113120
const at::Tensor& idxs,
121+
const int norm,
114122
const at::Tensor& grad_dists);
115123

116124
// Implementation which is exposed.
@@ -120,19 +128,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
120128
const at::Tensor& lengths1,
121129
const at::Tensor& lengths2,
122130
const at::Tensor& idxs,
131+
const int norm,
123132
const at::Tensor& grad_dists) {
124133
if (p1.is_cuda() || p2.is_cuda()) {
125134
#ifdef WITH_CUDA
126135
CHECK_CUDA(p1);
127136
CHECK_CUDA(p2);
128137
return KNearestNeighborBackwardCuda(
129-
p1, p2, lengths1, lengths2, idxs, grad_dists);
138+
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
130139
#else
131140
AT_ERROR("Not compiled with GPU support.");
132141
#endif
133142
}
134143
return KNearestNeighborBackwardCpu(
135-
p1, p2, lengths1, lengths2, idxs, grad_dists);
144+
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
136145
}
137146

138147
// Utility to check whether a KNN version can be used.

pytorch3d/csrc/knn/knn_cpu.cpp

+16-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
1515
const at::Tensor& p2,
1616
const at::Tensor& lengths1,
1717
const at::Tensor& lengths2,
18-
int K) {
18+
const int norm,
19+
const int K) {
1920
const int N = p1.size(0);
2021
const int P1 = p1.size(1);
2122
const int D = p1.size(2);
@@ -41,7 +42,11 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
4142
float dist = 0;
4243
for (int d = 0; d < D; ++d) {
4344
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
44-
dist += diff * diff;
45+
if (norm == 1) {
46+
dist += abs(diff);
47+
} else { // norm is 2 (default)
48+
dist += diff * diff;
49+
}
4550
}
4651
int size = static_cast<int>(q.size());
4752
if (size < K || dist < std::get<0>(q.top())) {
@@ -73,6 +78,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
7378
const at::Tensor& lengths1,
7479
const at::Tensor& lengths2,
7580
const at::Tensor& idxs,
81+
const int norm,
7682
const at::Tensor& grad_dists) {
7783
const int N = p1.size(0);
7884
const int P1 = p1.size(1);
@@ -104,8 +110,14 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
104110
continue;
105111
}
106112
for (int64_t d = 0; d < D; ++d) {
107-
const float diff =
108-
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
113+
float diff = 0.0;
114+
if (norm == 1) {
115+
float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
116+
diff = grad_dists_a[n][i1][k] * sign;
117+
} else { // norm is 2 (default)
118+
diff = 2.0f * grad_dists_a[n][i1][k] *
119+
(p1_a[n][i1][d] - p2_a[n][i2][d]);
120+
}
109121
grad_p1_a[n][i1][d] += diff;
110122
grad_p2_a[n][i2][d] += -1.0f * diff;
111123
}

0 commit comments

Comments
 (0)