Skip to content

Commit 44702fd

Browse files
rchen152 authored and facebook-github-bot committed
Add "max" point reduction for chamfer distance
Summary:
* Adds a "max" option for the point_reduction input to the chamfer_distance function.
* When combining the x and y directions, maxes the losses instead of summing them when point_reduction="max".
* Moves batch reduction to happen after the directions are combined.
* Adds test_chamfer_point_reduction_max and test_single_directional_chamfer_point_reduction_max tests.

Fixes #1838

Reviewed By: bottler

Differential Revision: D60614661

fbshipit-source-id: 7879816acfda03e945bada951b931d2c522756eb
1 parent 7edaee7 commit 44702fd

File tree

2 files changed

+136
-35
lines changed

2 files changed

+136
-35
lines changed

Diff for: pytorch3d/loss/chamfer.py

+55-33
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ def _validate_chamfer_reduction_inputs(
2727
"""
2828
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
2929
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
30-
if point_reduction is not None and point_reduction not in ["mean", "sum"]:
31-
raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
30+
if point_reduction is not None and point_reduction not in ["mean", "sum", "max"]:
31+
raise ValueError(
32+
'point_reduction must be one of ["mean", "sum", "max"] or None'
33+
)
3234
if point_reduction is None and batch_reduction is not None:
3335
raise ValueError("Batch reduction must be None if point_reduction is None")
3436

@@ -80,7 +82,6 @@ def _chamfer_distance_single_direction(
8082
x_normals,
8183
y_normals,
8284
weights,
83-
batch_reduction: Union[str, None],
8485
point_reduction: Union[str, None],
8586
norm: int,
8687
abs_cosine: bool,
@@ -103,11 +104,6 @@ def _chamfer_distance_single_direction(
103104
raise ValueError("weights cannot be negative.")
104105
if weights.sum() == 0.0:
105106
weights = weights.view(N, 1)
106-
if batch_reduction in ["mean", "sum"]:
107-
return (
108-
(x.sum((1, 2)) * weights).sum() * 0.0,
109-
(x.sum((1, 2)) * weights).sum() * 0.0,
110-
)
111107
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
112108

113109
cham_norm_x = x.new_zeros(())
@@ -135,7 +131,10 @@ def _chamfer_distance_single_direction(
135131
if weights is not None:
136132
cham_norm_x *= weights.view(N, 1)
137133

138-
if point_reduction is not None:
134+
if point_reduction == "max":
135+
assert not return_normals
136+
cham_x = cham_x.max(1).values # (N,)
137+
elif point_reduction is not None:
139138
# Apply point reduction
140139
cham_x = cham_x.sum(1) # (N,)
141140
if return_normals:
@@ -146,22 +145,34 @@ def _chamfer_distance_single_direction(
146145
if return_normals:
147146
cham_norm_x /= x_lengths_clamped
148147

149-
if batch_reduction is not None:
150-
# batch_reduction == "sum"
151-
cham_x = cham_x.sum()
152-
if return_normals:
153-
cham_norm_x = cham_norm_x.sum()
154-
if batch_reduction == "mean":
155-
div = weights.sum() if weights is not None else max(N, 1)
156-
cham_x /= div
157-
if return_normals:
158-
cham_norm_x /= div
159-
160148
cham_dist = cham_x
161149
cham_normals = cham_norm_x if return_normals else None
162150
return cham_dist, cham_normals
163151

164152

153+
def _apply_batch_reduction(
154+
cham_x, cham_norm_x, weights, batch_reduction: Union[str, None]
155+
):
156+
if batch_reduction is None:
157+
return (cham_x, cham_norm_x)
158+
# batch_reduction == "sum"
159+
N = cham_x.shape[0]
160+
cham_x = cham_x.sum()
161+
if cham_norm_x is not None:
162+
cham_norm_x = cham_norm_x.sum()
163+
if batch_reduction == "mean":
164+
if weights is None:
165+
div = max(N, 1)
166+
elif weights.sum() == 0.0:
167+
div = 1
168+
else:
169+
div = weights.sum()
170+
cham_x /= div
171+
if cham_norm_x is not None:
172+
cham_norm_x /= div
173+
return (cham_x, cham_norm_x)
174+
175+
165176
def chamfer_distance(
166177
x,
167178
y,
@@ -197,7 +208,8 @@ def chamfer_distance(
197208
batch_reduction: Reduction operation to apply for the loss across the
198209
batch, can be one of ["mean", "sum"] or None.
199210
point_reduction: Reduction operation to apply for the loss across the
200-
points, can be one of ["mean", "sum"] or None.
211+
points, can be one of ["mean", "sum", "max"] or None. Using "max" leads to the
212+
Hausdorff distance.
201213
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
202214
single_directional: If False (default), loss comes from both the distance between
203215
each point in x and its nearest neighbor in y and each point in y and its nearest
@@ -227,6 +239,10 @@ def chamfer_distance(
227239

228240
if not ((norm == 1) or (norm == 2)):
229241
raise ValueError("Support for 1 or 2 norm.")
242+
243+
if point_reduction == "max" and (x_normals is not None or y_normals is not None):
244+
raise ValueError('Normals must be None if point_reduction is "max"')
245+
230246
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
231247
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
232248

@@ -238,13 +254,13 @@ def chamfer_distance(
238254
x_normals,
239255
y_normals,
240256
weights,
241-
batch_reduction,
242257
point_reduction,
243258
norm,
244259
abs_cosine,
245260
)
246261
if single_directional:
247-
return cham_x, cham_norm_x
262+
loss = cham_x
263+
loss_normals = cham_norm_x
248264
else:
249265
cham_y, cham_norm_y = _chamfer_distance_single_direction(
250266
y,
@@ -254,17 +270,23 @@ def chamfer_distance(
254270
y_normals,
255271
x_normals,
256272
weights,
257-
batch_reduction,
258273
point_reduction,
259274
norm,
260275
abs_cosine,
261276
)
262-
if point_reduction is not None:
263-
return (
264-
cham_x + cham_y,
265-
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
266-
)
267-
return (
268-
(cham_x, cham_y),
269-
(cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
270-
)
277+
if point_reduction == "max":
278+
loss = torch.maximum(cham_x, cham_y)
279+
loss_normals = None
280+
elif point_reduction is not None:
281+
loss = cham_x + cham_y
282+
if cham_norm_x is not None:
283+
loss_normals = cham_norm_x + cham_norm_y
284+
else:
285+
loss_normals = None
286+
else:
287+
loss = (cham_x, cham_y)
288+
if cham_norm_x is not None:
289+
loss_normals = (cham_norm_x, cham_norm_y)
290+
else:
291+
loss_normals = None
292+
return _apply_batch_reduction(loss, loss_normals, weights, batch_reduction)

Diff for: tests/test_chamfer.py

+81-2
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,85 @@ def test_single_direction_chamfer_point_reduction_none(self):
847847
loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
848848
)
849849

850+
def test_chamfer_point_reduction_max(self):
851+
"""
852+
Compare output of vectorized chamfer loss with naive implementation
853+
for point_reduction = "max" and batch_reduction = None.
854+
"""
855+
N, P1, P2 = 7, 10, 18
856+
device = get_random_cuda_device()
857+
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
858+
p1 = points_normals.p1
859+
p2 = points_normals.p2
860+
weights = points_normals.weights
861+
p11 = p1.detach().clone()
862+
p22 = p2.detach().clone()
863+
p11.requires_grad = True
864+
p22.requires_grad = True
865+
866+
pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
867+
p1, p2, x_normals=None, y_normals=None
868+
)
869+
870+
loss, loss_norm = chamfer_distance(
871+
p11,
872+
p22,
873+
x_normals=None,
874+
y_normals=None,
875+
weights=weights,
876+
batch_reduction=None,
877+
point_reduction="max",
878+
)
879+
pred_loss_max = torch.maximum(
880+
pred_loss[0].max(1).values, pred_loss[1].max(1).values
881+
)
882+
pred_loss_max *= weights
883+
self.assertClose(loss, pred_loss_max)
884+
885+
self.assertIsNone(loss_norm)
886+
887+
# Check gradients
888+
self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
889+
890+
def test_single_directional_chamfer_point_reduction_max(self):
891+
"""
892+
Compare output of vectorized single directional chamfer loss with naive implementation
893+
for point_reduction = "max" and batch_reduction = None.
894+
"""
895+
N, P1, P2 = 7, 10, 18
896+
device = get_random_cuda_device()
897+
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
898+
p1 = points_normals.p1
899+
p2 = points_normals.p2
900+
weights = points_normals.weights
901+
p11 = p1.detach().clone()
902+
p22 = p2.detach().clone()
903+
p11.requires_grad = True
904+
p22.requires_grad = True
905+
906+
pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
907+
p1, p2, x_normals=None, y_normals=None
908+
)
909+
910+
loss, loss_norm = chamfer_distance(
911+
p11,
912+
p22,
913+
x_normals=None,
914+
y_normals=None,
915+
weights=weights,
916+
batch_reduction=None,
917+
point_reduction="max",
918+
single_directional=True,
919+
)
920+
pred_loss_max = pred_loss[0].max(1).values
921+
pred_loss_max *= weights
922+
self.assertClose(loss, pred_loss_max)
923+
924+
self.assertIsNone(loss_norm)
925+
926+
# Check gradients
927+
self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
928+
850929
def _check_gradients(
851930
self,
852931
loss,
@@ -1020,9 +1099,9 @@ def test_chamfer_joint_reduction(self):
10201099
with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
10211100
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
10221101

1023-
# Error when point_reduction is not in ["mean", "sum"] or None.
1102+
# Error when point_reduction is not in ["mean", "sum", "max"] or None.
10241103
with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
1025-
chamfer_distance(p1, p2, weights=weights, point_reduction="max")
1104+
chamfer_distance(p1, p2, weights=weights, point_reduction="min")
10261105

10271106
def test_incorrect_weights(self):
10281107
N, P1, P2 = 16, 64, 128

0 commit comments

Comments (0)