Skip to content

Commit ccfb72c

Browse files
gkioxarifacebook-github-bot
authored andcommitted
small fix for iou3d
Summary: A small numerical fix for IoU for 3D boxes, fixes GH #992 * Adds a check for boxes with zero side areas (invalid boxes) * Fixes numerical issue when two boxes have coplanar sides Reviewed By: nikhilaravi Differential Revision: D33195691 fbshipit-source-id: 8a34b4d1f1e5ec2edb6d54143930da44bdde0906
1 parent 069c9fd commit ccfb72c

File tree

6 files changed

+202
-4
lines changed

6 files changed

+202
-4
lines changed

pytorch3d/csrc/iou_box3d/iou_box3d.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ __global__ void IoUBox3DKernel(
9090
for (int b2 = 0; b2 < box2_count; ++b2) {
9191
const bool is_coplanar =
9292
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
93-
if (is_coplanar) {
93+
const float area = FaceArea(box1_intersect[b1]);
94+
if ((is_coplanar) && (area > kEpsilon)) {
9495
tri2_keep[b2].keep = false;
9596
}
9697
}

pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
8181
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
8282
const bool is_coplanar =
8383
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
84-
if (is_coplanar) {
84+
const float area = FaceArea(box1_intersect[b1]);
85+
if ((is_coplanar) && (area > kEpsilon)) {
8586
tri2_keep[b2] = 0;
8687
}
8788
}

pytorch3d/csrc/iou_box3d/iou_utils.cuh

+20
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,26 @@ FaceNormal(const float3 v0, const float3 v1, const float3 v2) {
138138
return n;
139139
}
140140

141+
// The area of the face defined by vertices (v0, v1, v2)
142+
// Define e0 to be the edge connecting (v1, v0)
143+
// Define e1 to be the edge connecting (v2, v0)
144+
// Area is the norm of the cross product of e0, e1 divided by 2.0
145+
//
146+
// Args
147+
// tri: FaceVerts of float3 coordinates of the vertices of the face
148+
//
149+
// Returns
150+
// float: area for the face
151+
//
152+
__device__ inline float FaceArea(const FaceVerts& tri) {
153+
// Get verts for face 1
154+
const float3 v0 = tri.v0;
155+
const float3 v1 = tri.v1;
156+
const float3 v2 = tri.v2;
157+
const float3 n = cross(v1 - v0, v2 - v0);
158+
return norm(n) / 2.0;
159+
}
160+
141161
// The normal of a box plane defined by the verts in `plane` with
142162
// the centroid of the box given by `center`.
143163
// Args

pytorch3d/csrc/iou_box3d/iou_utils.h

+20
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,26 @@ inline vec3<float> FaceNormal(vec3<float> v0, vec3<float> v1, vec3<float> v2) {
145145
return n;
146146
}
147147

148+
// The area of the face defined by vertices (v0, v1, v2)
149+
// Define e0 to be the edge connecting (v1, v0)
150+
// Define e1 to be the edge connecting (v2, v0)
151+
// Area is the norm of the cross product of e0, e1 divided by 2.0
152+
//
153+
// Args
154+
// tri: vec3 coordinates of the vertices of the face
155+
//
156+
// Returns
157+
// float: area for the face
158+
//
159+
inline float FaceArea(const std::vector<vec3<float>>& tri) {
160+
// Get verts for face
161+
const vec3<float> v0 = tri[0];
162+
const vec3<float> v1 = tri[1];
163+
const vec3<float> v2 = tri[2];
164+
const vec3<float> n = cross(v1 - v0, v2 - v0);
165+
return norm(n) / 2.0;
166+
}
167+
148168
// The normal of a box plane defined by the verts in `plane` with
149169
// the centroid of the box given by `center`.
150170
// Args

pytorch3d/ops/iou_box3d.py

+24
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,28 @@ def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-4) -> None:
6969
return
7070

7171

72+
def _check_nonzero(boxes: torch.Tensor, eps: float = 1e-4) -> None:
73+
"""
74+
Checks that the sides of the box have a non zero area
75+
"""
76+
faces = torch.tensor(_box_triangles, dtype=torch.int64, device=boxes.device)
77+
# pyre-fixme[16]: `boxes` has no attribute `index_select`.
78+
verts = boxes.index_select(index=faces.view(-1), dim=1)
79+
B = boxes.shape[0]
80+
T, V = faces.shape
81+
# (B, T, 3, 3) -> (B, T, 3)
82+
v0, v1, v2 = verts.reshape(B, T, V, 3).unbind(2)
83+
84+
normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # (B, T, 3)
85+
face_areas = normals.norm(dim=-1) / 2
86+
87+
if (face_areas < eps).any().item():
88+
msg = "Planes have zero areas"
89+
raise ValueError(msg)
90+
91+
return
92+
93+
7294
class _box3d_overlap(Function):
7395
"""
7496
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
@@ -138,6 +160,8 @@ def box3d_overlap(
138160

139161
_check_coplanar(boxes1, eps)
140162
_check_coplanar(boxes2, eps)
163+
_check_nonzero(boxes1, eps)
164+
_check_nonzero(boxes2, eps)
141165

142166
# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
143167
vol, iou = _box3d_overlap.apply(boxes1, boxes2)

tests/test_iou_box3d.py

+134-2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def _test_iou(self, overlap_fn, device):
111111
self.assertClose(
112112
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
113113
)
114+
# symmetry
115+
vol, iou = overlap_fn(box2[None], box1[None])
116+
self.assertClose(
117+
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
118+
)
114119

115120
# 3rd test
116121
dd = random.random()
@@ -119,6 +124,11 @@ def _test_iou(self, overlap_fn, device):
119124
self.assertClose(
120125
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
121126
)
127+
# symmetry
128+
vol, _ = overlap_fn(box2[None], box1[None])
129+
self.assertClose(
130+
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
131+
)
122132

123133
# 4th test
124134
ddx, ddy, ddz = random.random(), random.random(), random.random()
@@ -132,6 +142,16 @@ def _test_iou(self, overlap_fn, device):
132142
dtype=vol.dtype,
133143
),
134144
)
145+
# symmetry
146+
vol, _ = overlap_fn(box2[None], box1[None])
147+
self.assertClose(
148+
vol,
149+
torch.tensor(
150+
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
151+
device=vol.device,
152+
dtype=vol.dtype,
153+
),
154+
)
135155

136156
# Also check IoU is 1 when computing overlap with the same shifted box
137157
vol, iou = overlap_fn(box2[None], box2[None])
@@ -152,6 +172,16 @@ def _test_iou(self, overlap_fn, device):
152172
dtype=vol.dtype,
153173
),
154174
)
175+
# symmetry
176+
vol, _ = overlap_fn(box2r[None], box1r[None])
177+
self.assertClose(
178+
vol,
179+
torch.tensor(
180+
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
181+
device=vol.device,
182+
dtype=vol.dtype,
183+
),
184+
)
155185

156186
# 6th test
157187
ddx, ddy, ddz = random.random(), random.random(), random.random()
@@ -170,6 +200,17 @@ def _test_iou(self, overlap_fn, device):
170200
),
171201
atol=1e-7,
172202
)
203+
# symmetry
204+
vol, _ = overlap_fn(box2r[None], box1r[None])
205+
self.assertClose(
206+
vol,
207+
torch.tensor(
208+
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
209+
device=vol.device,
210+
dtype=vol.dtype,
211+
),
212+
atol=1e-7,
213+
)
173214

174215
# 7th test: hand coded example and test with meshlab output
175216

@@ -214,6 +255,10 @@ def _test_iou(self, overlap_fn, device):
214255
vol, iou = overlap_fn(box1r[None], box2r[None])
215256
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
216257
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
258+
# symmetry
259+
vol, iou = overlap_fn(box2r[None], box1r[None])
260+
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
261+
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
217262

218263
# 8th test: compare with sampling
219264
# create box1
@@ -232,14 +277,20 @@ def _test_iou(self, overlap_fn, device):
232277
iou_sampling = self._box3d_overlap_sampling_batched(
233278
box1r[None], box2r[None], num_samples=10000
234279
)
235-
280+
self.assertClose(iou, iou_sampling, atol=1e-2)
281+
# symmetry
282+
vol, iou = overlap_fn(box2r[None], box1r[None])
236283
self.assertClose(iou, iou_sampling, atol=1e-2)
237284

238285
# 9th test: non overlapping boxes, iou = 0.0
239286
box2 = box1 + torch.tensor([[0.0, 100.0, 0.0]], device=device)
240287
vol, iou = overlap_fn(box1[None], box2[None])
241288
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
242289
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
290+
# symmetry
291+
vol, iou = overlap_fn(box2[None], box1[None])
292+
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
293+
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
243294

244295
# 10th test: Non coplanar verts in a plane
245296
box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
@@ -284,6 +335,56 @@ def _test_iou(self, overlap_fn, device):
284335
vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
285336
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
286337
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
338+
# symmetry
339+
vols, ious = overlap_fn(box_skew_2[None], box_skew_1[None])
340+
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
341+
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
342+
343+
# 12th test: Zero area bounding box (from GH issue #992)
344+
box12a = torch.tensor(
345+
[
346+
[-1.0000, -1.0000, -0.5000],
347+
[1.0000, -1.0000, -0.5000],
348+
[1.0000, 1.0000, -0.5000],
349+
[-1.0000, 1.0000, -0.5000],
350+
[-1.0000, -1.0000, 0.5000],
351+
[1.0000, -1.0000, 0.5000],
352+
[1.0000, 1.0000, 0.5000],
353+
[-1.0000, 1.0000, 0.5000],
354+
],
355+
device=device,
356+
dtype=torch.float32,
357+
)
358+
359+
box12b = torch.tensor(
360+
[
361+
[0.0, 0.0, 0.0],
362+
[0.0, 0.0, 0.0],
363+
[0.0, 0.0, 0.0],
364+
[0.0, 0.0, 0.0],
365+
[0.0, 0.0, 0.0],
366+
[0.0, 0.0, 0.0],
367+
[0.0, 0.0, 0.0],
368+
[0.0, 0.0, 0.0],
369+
],
370+
device=device,
371+
dtype=torch.float32,
372+
)
373+
msg = "Planes have zero areas"
374+
with self.assertRaisesRegex(ValueError, msg):
375+
overlap_fn(box12a[None], box12b[None])
376+
# symmetry
377+
with self.assertRaisesRegex(ValueError, msg):
378+
overlap_fn(box12b[None], box12a[None])
379+
380+
# 13th test: From GH issue #992
381+
# Zero area coplanar face after intersection
382+
ctrs = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0]])
383+
whl = torch.tensor([[2.0, 2.0, 2.0], [2.0, 2, 2]])
384+
box13a = TestIoU3D.create_box(ctrs[0], whl[0])
385+
box13b = TestIoU3D.create_box(ctrs[1], whl[1])
386+
vol, iou = overlap_fn(box13a[None], box13b[None])
387+
self.assertClose(vol, torch.tensor([[2.0]], device=vol.device, dtype=vol.dtype))
287388

288389
def _test_real_boxes(self, overlap_fn, device):
289390
data_filename = "./real_boxes.pkl"
@@ -577,6 +678,13 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
577678
msg = "Plane vertices are not coplanar"
578679
raise ValueError(msg)
579680

681+
# Check all faces have non zero area
682+
area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
683+
area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
684+
if (area1 < eps).any().item() or (area2 < eps).any().item():
685+
msg = "Planes have zero areas"
686+
raise ValueError(msg)
687+
580688
# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
581689
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
582690
# since that e0 is orthogonal to n. Same for e1.
@@ -607,6 +715,27 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
607715
return n
608716

609717

718+
def tri_verts_area(tri_verts: torch.Tensor) -> torch.Tensor:
719+
"""
720+
Computes the area of the triangle faces in tri_verts
721+
Args:
722+
tri_verts: tensor of shape (T, 3, 3)
723+
Returns:
724+
areas: the area of the triangles (T, 1)
725+
"""
726+
add_dim = False
727+
if tri_verts.ndim == 2:
728+
tri_verts = tri_verts.unsqueeze(0)
729+
add_dim = True
730+
731+
v0, v1, v2 = tri_verts.unbind(1)
732+
areas = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2.0
733+
734+
if add_dim:
735+
areas = areas[0]
736+
return areas
737+
738+
610739
def box_volume(box: torch.Tensor) -> torch.Tensor:
611740
"""
612741
Computes the volume of each box in boxes.
@@ -988,7 +1117,10 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
9881117
keep2 = torch.ones((tri_verts2.shape[0],), device=device, dtype=torch.bool)
9891118
for i1 in range(tri_verts1.shape[0]):
9901119
for i2 in range(tri_verts2.shape[0]):
991-
if coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2]):
1120+
if (
1121+
coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2])
1122+
and tri_verts_area(tri_verts1[i1]) > 1e-4
1123+
):
9921124
keep2[i2] = 0
9931125
keep2 = keep2.nonzero()[:, 0]
9941126
tri_verts2 = tri_verts2[keep2]

0 commit comments

Comments
 (0)