@@ -27,8 +27,10 @@ def _validate_chamfer_reduction_inputs(
27
27
"""
28
28
if batch_reduction is not None and batch_reduction not in ["mean" , "sum" ]:
29
29
raise ValueError ('batch_reduction must be one of ["mean", "sum"] or None' )
30
- if point_reduction is not None and point_reduction not in ["mean" , "sum" ]:
31
- raise ValueError ('point_reduction must be one of ["mean", "sum"] or None' )
30
+ if point_reduction is not None and point_reduction not in ["mean" , "sum" , "max" ]:
31
+ raise ValueError (
32
+ 'point_reduction must be one of ["mean", "sum", "max"] or None'
33
+ )
32
34
if point_reduction is None and batch_reduction is not None :
33
35
raise ValueError ("Batch reduction must be None if point_reduction is None" )
34
36
@@ -80,7 +82,6 @@ def _chamfer_distance_single_direction(
80
82
x_normals ,
81
83
y_normals ,
82
84
weights ,
83
- batch_reduction : Union [str , None ],
84
85
point_reduction : Union [str , None ],
85
86
norm : int ,
86
87
abs_cosine : bool ,
@@ -103,11 +104,6 @@ def _chamfer_distance_single_direction(
103
104
raise ValueError ("weights cannot be negative." )
104
105
if weights .sum () == 0.0 :
105
106
weights = weights .view (N , 1 )
106
- if batch_reduction in ["mean" , "sum" ]:
107
- return (
108
- (x .sum ((1 , 2 )) * weights ).sum () * 0.0 ,
109
- (x .sum ((1 , 2 )) * weights ).sum () * 0.0 ,
110
- )
111
107
return ((x .sum ((1 , 2 )) * weights ) * 0.0 , (x .sum ((1 , 2 )) * weights ) * 0.0 )
112
108
113
109
cham_norm_x = x .new_zeros (())
@@ -135,7 +131,10 @@ def _chamfer_distance_single_direction(
135
131
if weights is not None :
136
132
cham_norm_x *= weights .view (N , 1 )
137
133
138
- if point_reduction is not None :
134
+ if point_reduction == "max" :
135
+ assert not return_normals
136
+ cham_x = cham_x .max (1 ).values # (N,)
137
+ elif point_reduction is not None :
139
138
# Apply point reduction
140
139
cham_x = cham_x .sum (1 ) # (N,)
141
140
if return_normals :
@@ -146,22 +145,34 @@ def _chamfer_distance_single_direction(
146
145
if return_normals :
147
146
cham_norm_x /= x_lengths_clamped
148
147
149
- if batch_reduction is not None :
150
- # batch_reduction == "sum"
151
- cham_x = cham_x .sum ()
152
- if return_normals :
153
- cham_norm_x = cham_norm_x .sum ()
154
- if batch_reduction == "mean" :
155
- div = weights .sum () if weights is not None else max (N , 1 )
156
- cham_x /= div
157
- if return_normals :
158
- cham_norm_x /= div
159
-
160
148
cham_dist = cham_x
161
149
cham_normals = cham_norm_x if return_normals else None
162
150
return cham_dist , cham_normals
163
151
164
152
153
+ def _apply_batch_reduction (
154
+ cham_x , cham_norm_x , weights , batch_reduction : Union [str , None ]
155
+ ):
156
+ if batch_reduction is None :
157
+ return (cham_x , cham_norm_x )
158
+ # batch_reduction == "sum"
159
+ N = cham_x .shape [0 ]
160
+ cham_x = cham_x .sum ()
161
+ if cham_norm_x is not None :
162
+ cham_norm_x = cham_norm_x .sum ()
163
+ if batch_reduction == "mean" :
164
+ if weights is None :
165
+ div = max (N , 1 )
166
+ elif weights .sum () == 0.0 :
167
+ div = 1
168
+ else :
169
+ div = weights .sum ()
170
+ cham_x /= div
171
+ if cham_norm_x is not None :
172
+ cham_norm_x /= div
173
+ return (cham_x , cham_norm_x )
174
+
175
+
165
176
def chamfer_distance (
166
177
x ,
167
178
y ,
@@ -197,7 +208,8 @@ def chamfer_distance(
197
208
batch_reduction: Reduction operation to apply for the loss across the
198
209
batch, can be one of ["mean", "sum"] or None.
199
210
point_reduction: Reduction operation to apply for the loss across the
200
- points, can be one of ["mean", "sum"] or None.
211
+ points, can be one of ["mean", "sum", "max"] or None. Using "max" leads to the
212
+ Hausdorff distance.
201
213
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
202
214
single_directional: If False (default), loss comes from both the distance between
203
215
each point in x and its nearest neighbor in y and each point in y and its nearest
@@ -227,6 +239,10 @@ def chamfer_distance(
227
239
228
240
if not ((norm == 1 ) or (norm == 2 )):
229
241
raise ValueError ("Support for 1 or 2 norm." )
242
+
243
+ if point_reduction == "max" and (x_normals is not None or y_normals is not None ):
244
+ raise ValueError ('Normals must be None if point_reduction is "max"' )
245
+
230
246
x , x_lengths , x_normals = _handle_pointcloud_input (x , x_lengths , x_normals )
231
247
y , y_lengths , y_normals = _handle_pointcloud_input (y , y_lengths , y_normals )
232
248
@@ -238,13 +254,13 @@ def chamfer_distance(
238
254
x_normals ,
239
255
y_normals ,
240
256
weights ,
241
- batch_reduction ,
242
257
point_reduction ,
243
258
norm ,
244
259
abs_cosine ,
245
260
)
246
261
if single_directional :
247
- return cham_x , cham_norm_x
262
+ loss = cham_x
263
+ loss_normals = cham_norm_x
248
264
else :
249
265
cham_y , cham_norm_y = _chamfer_distance_single_direction (
250
266
y ,
@@ -254,17 +270,23 @@ def chamfer_distance(
254
270
y_normals ,
255
271
x_normals ,
256
272
weights ,
257
- batch_reduction ,
258
273
point_reduction ,
259
274
norm ,
260
275
abs_cosine ,
261
276
)
262
- if point_reduction is not None :
263
- return (
264
- cham_x + cham_y ,
265
- (cham_norm_x + cham_norm_y ) if cham_norm_x is not None else None ,
266
- )
267
- return (
268
- (cham_x , cham_y ),
269
- (cham_norm_x , cham_norm_y ) if cham_norm_x is not None else None ,
270
- )
277
+ if point_reduction == "max" :
278
+ loss = torch .maximum (cham_x , cham_y )
279
+ loss_normals = None
280
+ elif point_reduction is not None :
281
+ loss = cham_x + cham_y
282
+ if cham_norm_x is not None :
283
+ loss_normals = cham_norm_x + cham_norm_y
284
+ else :
285
+ loss_normals = None
286
+ else :
287
+ loss = (cham_x , cham_y )
288
+ if cham_norm_x is not None :
289
+ loss_normals = (cham_norm_x , cham_norm_y )
290
+ else :
291
+ loss_normals = None
292
+ return _apply_batch_reduction (loss , loss_normals , weights , batch_reduction )
0 commit comments