Skip to content

Commit 471b126

Browse files
winnie1994facebook-github-bot
authored andcommitted
add min_triangle_area argument to IsInsideTriangle
Summary: 1. changed IsInsideTriangle in geometry_utils to take in min_triangle_area parameter instead of hardcoded value 2. updated point_mesh_cpu.cpp and point_mesh_cuda.[h/cu] to adapt to changes in geometry_utils function signatures 3. updated point_mesh_distance.py and test_point_mesh_distance.py to modify _C. calls Reviewed By: bottler Differential Revision: D34459764 fbshipit-source-id: 0549e78713c6d68f03d85fb597a13dd88e09b686
1 parent 4d043fc commit 471b126

File tree

7 files changed

+344
-134
lines changed

7 files changed

+344
-134
lines changed

Diff for: pytorch3d/csrc/point_mesh/point_mesh_cpu.cpp

+72-42
Original file line numberDiff line numberDiff line change
@@ -57,29 +57,33 @@ void IncrementPoint(at::TensorAccessor<T, 1>&& t, const vec3<T>& point) {
5757
template <typename T>
5858
T HullDistance(
5959
const std::array<vec3<T>, 1>& a,
60-
const std::array<vec3<T>, 2>& b) {
60+
const std::array<vec3<T>, 2>& b,
61+
const double /*min_triangle_area*/) {
6162
using std::get;
6263
return PointLine3DistanceForward(get<0>(a), get<0>(b), get<1>(b));
6364
}
6465
template <typename T>
6566
T HullDistance(
6667
const std::array<vec3<T>, 1>& a,
67-
const std::array<vec3<T>, 3>& b) {
68+
const std::array<vec3<T>, 3>& b,
69+
const double min_triangle_area) {
6870
using std::get;
6971
return PointTriangle3DistanceForward(
70-
get<0>(a), get<0>(b), get<1>(b), get<2>(b));
72+
get<0>(a), get<0>(b), get<1>(b), get<2>(b), min_triangle_area);
7173
}
7274
template <typename T>
7375
T HullDistance(
7476
const std::array<vec3<T>, 2>& a,
75-
const std::array<vec3<T>, 1>& b) {
76-
return HullDistance(b, a);
77+
const std::array<vec3<T>, 1>& b,
78+
const double /*min_triangle_area*/) {
79+
return HullDistance(b, a, 1);
7780
}
7881
template <typename T>
7982
T HullDistance(
8083
const std::array<vec3<T>, 3>& a,
81-
const std::array<vec3<T>, 1>& b) {
82-
return HullDistance(b, a);
84+
const std::array<vec3<T>, 1>& b,
85+
const double min_triangle_area) {
86+
return HullDistance(b, a, min_triangle_area);
8387
}
8488

8589
template <typename T>
@@ -88,7 +92,8 @@ void HullHullDistanceBackward(
8892
const std::array<vec3<T>, 2>& b,
8993
T grad_dist,
9094
at::TensorAccessor<T, 1>&& grad_a,
91-
at::TensorAccessor<T, 2>&& grad_b) {
95+
at::TensorAccessor<T, 2>&& grad_b,
96+
const double /*min_triangle_area*/) {
9297
using std::get;
9398
auto res =
9499
PointLine3DistanceBackward(get<0>(a), get<0>(b), get<1>(b), grad_dist);
@@ -102,10 +107,11 @@ void HullHullDistanceBackward(
102107
const std::array<vec3<T>, 3>& b,
103108
T grad_dist,
104109
at::TensorAccessor<T, 1>&& grad_a,
105-
at::TensorAccessor<T, 2>&& grad_b) {
110+
at::TensorAccessor<T, 2>&& grad_b,
111+
const double min_triangle_area) {
106112
using std::get;
107113
auto res = PointTriangle3DistanceBackward(
108-
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist);
114+
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist, min_triangle_area);
109115
IncrementPoint(std::move(grad_a), get<0>(res));
110116
IncrementPoint(grad_b[0], get<1>(res));
111117
IncrementPoint(grad_b[1], get<2>(res));
@@ -117,19 +123,21 @@ void HullHullDistanceBackward(
117123
const std::array<vec3<T>, 1>& b,
118124
T grad_dist,
119125
at::TensorAccessor<T, 2>&& grad_a,
120-
at::TensorAccessor<T, 1>&& grad_b) {
126+
at::TensorAccessor<T, 1>&& grad_b,
127+
const double min_triangle_area) {
121128
return HullHullDistanceBackward(
122-
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
129+
b, a, grad_dist, std::move(grad_b), std::move(grad_a), min_triangle_area);
123130
}
124131
template <typename T>
125132
void HullHullDistanceBackward(
126133
const std::array<vec3<T>, 2>& a,
127134
const std::array<vec3<T>, 1>& b,
128135
T grad_dist,
129136
at::TensorAccessor<T, 2>&& grad_a,
130-
at::TensorAccessor<T, 1>&& grad_b) {
137+
at::TensorAccessor<T, 1>&& grad_b,
138+
const double /*min_triangle_area*/) {
131139
return HullHullDistanceBackward(
132-
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
140+
b, a, grad_dist, std::move(grad_b), std::move(grad_a), 1);
133141
}
134142

135143
template <int H>
@@ -150,7 +158,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
150158
const at::Tensor& as,
151159
const at::Tensor& as_first_idx,
152160
const at::Tensor& bs,
153-
const at::Tensor& bs_first_idx) {
161+
const at::Tensor& bs_first_idx,
162+
const double min_triangle_area) {
154163
const int64_t A_N = as.size(0);
155164
const int64_t B_N = bs.size(0);
156165
const int64_t BATCHES = as_first_idx.size(0);
@@ -190,7 +199,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
190199
size_t min_idx = 0;
191200
auto a = ExtractHull<H1>(as_a[a_n]);
192201
for (int64_t b_n = b_batch_start; b_n < b_batch_end; ++b_n) {
193-
float dist = HullDistance(a, ExtractHull<H2>(bs_a[b_n]));
202+
float dist =
203+
HullDistance(a, ExtractHull<H2>(bs_a[b_n]), min_triangle_area);
194204
if (dist <= min_dist) {
195205
min_dist = dist;
196206
min_idx = b_n;
@@ -208,7 +218,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
208218
const at::Tensor& as,
209219
const at::Tensor& bs,
210220
const at::Tensor& idx_bs,
211-
const at::Tensor& grad_dists) {
221+
const at::Tensor& grad_dists,
222+
const double min_triangle_area) {
212223
const int64_t A_N = as.size(0);
213224

214225
TORCH_CHECK(idx_bs.size(0) == A_N);
@@ -230,15 +241,21 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
230241
auto a = ExtractHull<H1>(as_a[a_n]);
231242
auto b = ExtractHull<H2>(bs_a[idx_bs_a[a_n]]);
232243
HullHullDistanceBackward(
233-
a, b, grad_dists_a[a_n], grad_as_a[a_n], grad_bs_a[idx_bs_a[a_n]]);
244+
a,
245+
b,
246+
grad_dists_a[a_n],
247+
grad_as_a[a_n],
248+
grad_bs_a[idx_bs_a[a_n]],
249+
min_triangle_area);
234250
}
235251
return std::make_tuple(grad_as, grad_bs);
236252
}
237253

238254
template <int H>
239255
torch::Tensor PointHullArrayDistanceForwardCpu(
240256
const torch::Tensor& points,
241-
const torch::Tensor& bs) {
257+
const torch::Tensor& bs,
258+
const double min_triangle_area) {
242259
const int64_t P = points.size(0);
243260
const int64_t B_N = bs.size(0);
244261

@@ -254,7 +271,7 @@ torch::Tensor PointHullArrayDistanceForwardCpu(
254271
auto dest = dists_a[p];
255272
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
256273
auto b = ExtractHull<H>(bs_a[b_n]);
257-
dest[b_n] = HullDistance(point, b);
274+
dest[b_n] = HullDistance(point, b, min_triangle_area);
258275
}
259276
}
260277
return dists;
@@ -264,7 +281,8 @@ template <int H>
264281
std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
265282
const at::Tensor& points,
266283
const at::Tensor& bs,
267-
const at::Tensor& grad_dists) {
284+
const at::Tensor& grad_dists,
285+
const double min_triangle_area) {
268286
const int64_t P = points.size(0);
269287
const int64_t B_N = bs.size(0);
270288

@@ -287,7 +305,12 @@ std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
287305
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
288306
auto b = ExtractHull<H>(bs_a[b_n]);
289307
HullHullDistanceBackward(
290-
point, b, grad_dist[b_n], std::move(grad_point), grad_bs_a[b_n]);
308+
point,
309+
b,
310+
grad_dist[b_n],
311+
std::move(grad_point),
312+
grad_bs_a[b_n],
313+
min_triangle_area);
291314
}
292315
}
293316
return std::make_tuple(grad_points, grad_bs);
@@ -299,63 +322,70 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu(
299322
const torch::Tensor& points,
300323
const torch::Tensor& points_first_idx,
301324
const torch::Tensor& tris,
302-
const torch::Tensor& tris_first_idx) {
325+
const torch::Tensor& tris_first_idx,
326+
const double min_triangle_area) {
303327
return HullHullDistanceForwardCpu<1, 3>(
304-
points, points_first_idx, tris, tris_first_idx);
328+
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
305329
}
306330

307331
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu(
308332
const torch::Tensor& points,
309333
const torch::Tensor& tris,
310334
const torch::Tensor& idx_points,
311-
const torch::Tensor& grad_dists) {
335+
const torch::Tensor& grad_dists,
336+
const double min_triangle_area) {
312337
return HullHullDistanceBackwardCpu<1, 3>(
313-
points, tris, idx_points, grad_dists);
338+
points, tris, idx_points, grad_dists, min_triangle_area);
314339
}
315340

316341
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu(
317342
const torch::Tensor& points,
318343
const torch::Tensor& points_first_idx,
319344
const torch::Tensor& tris,
320-
const torch::Tensor& tris_first_idx) {
345+
const torch::Tensor& tris_first_idx,
346+
const double min_triangle_area) {
321347
return HullHullDistanceForwardCpu<3, 1>(
322-
tris, tris_first_idx, points, points_first_idx);
348+
tris, tris_first_idx, points, points_first_idx, min_triangle_area);
323349
}
324350

325351
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu(
326352
const torch::Tensor& points,
327353
const torch::Tensor& tris,
328354
const torch::Tensor& idx_tris,
329-
const torch::Tensor& grad_dists) {
330-
auto res =
331-
HullHullDistanceBackwardCpu<3, 1>(tris, points, idx_tris, grad_dists);
355+
const torch::Tensor& grad_dists,
356+
const double min_triangle_area) {
357+
auto res = HullHullDistanceBackwardCpu<3, 1>(
358+
tris, points, idx_tris, grad_dists, min_triangle_area);
332359
return std::make_tuple(std::get<1>(res), std::get<0>(res));
333360
}
334361

335362
torch::Tensor PointEdgeArrayDistanceForwardCpu(
336363
const torch::Tensor& points,
337364
const torch::Tensor& segms) {
338-
return PointHullArrayDistanceForwardCpu<2>(points, segms);
365+
return PointHullArrayDistanceForwardCpu<2>(points, segms, 1);
339366
}
340367

341368
std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCpu(
342369
const at::Tensor& points,
343370
const at::Tensor& tris,
344-
const at::Tensor& grad_dists) {
345-
return PointHullArrayDistanceBackwardCpu<3>(points, tris, grad_dists);
371+
const at::Tensor& grad_dists,
372+
const double min_triangle_area) {
373+
return PointHullArrayDistanceBackwardCpu<3>(
374+
points, tris, grad_dists, min_triangle_area);
346375
}
347376

348377
torch::Tensor PointFaceArrayDistanceForwardCpu(
349378
const torch::Tensor& points,
350-
const torch::Tensor& tris) {
351-
return PointHullArrayDistanceForwardCpu<3>(points, tris);
379+
const torch::Tensor& tris,
380+
const double min_triangle_area) {
381+
return PointHullArrayDistanceForwardCpu<3>(points, tris, min_triangle_area);
352382
}
353383

354384
std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCpu(
355385
const at::Tensor& points,
356386
const at::Tensor& segms,
357387
const at::Tensor& grad_dists) {
358-
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists);
388+
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists, 1);
359389
}
360390

361391
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
@@ -365,7 +395,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
365395
const torch::Tensor& segms_first_idx,
366396
const int64_t /*max_points*/) {
367397
return HullHullDistanceForwardCpu<1, 2>(
368-
points, points_first_idx, segms, segms_first_idx);
398+
points, points_first_idx, segms, segms_first_idx, 1);
369399
}
370400

371401
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
@@ -374,7 +404,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
374404
const torch::Tensor& idx_points,
375405
const torch::Tensor& grad_dists) {
376406
return HullHullDistanceBackwardCpu<1, 2>(
377-
points, segms, idx_points, grad_dists);
407+
points, segms, idx_points, grad_dists, 1);
378408
}
379409

380410
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
@@ -384,15 +414,15 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
384414
const torch::Tensor& segms_first_idx,
385415
const int64_t /*max_segms*/) {
386416
return HullHullDistanceForwardCpu<2, 1>(
387-
segms, segms_first_idx, points, points_first_idx);
417+
segms, segms_first_idx, points, points_first_idx, 1);
388418
}
389419

390420
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
391421
const torch::Tensor& points,
392422
const torch::Tensor& segms,
393423
const torch::Tensor& idx_segms,
394424
const torch::Tensor& grad_dists) {
395-
auto res =
396-
HullHullDistanceBackwardCpu<2, 1>(segms, points, idx_segms, grad_dists);
425+
auto res = HullHullDistanceBackwardCpu<2, 1>(
426+
segms, points, idx_segms, grad_dists, 1);
397427
return std::make_tuple(std::get<1>(res), std::get<0>(res));
398428
}

0 commit comments

Comments
 (0)