@@ -305,9 +305,20 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
305
305
predict (tuple[Tensor, Tensor]): The predicted class and bounding box.
306
306
307
307
Returns:
308
- tuple[Tensor, Tensor]: The aligned target tensor with (batch, targets, (class + 4)).
308
+ tuple[Tensor, Tensor]: The aligned target tensors with (batch, targets, (class + 4)) and (batch, targets ).
309
309
"""
310
310
predict_cls , predict_bbox = predict
311
+
312
+ # return if target has no gt information.
313
+ n_targets = target .shape [1 ]
314
+ if n_targets == 0 :
315
+ device = predict_bbox .device
316
+ align_cls = torch .zeros_like (predict_cls , device = device )
317
+ align_bbox = torch .zeros_like (predict_bbox , device = device )
318
+ valid_mask = torch .zeros (predict_cls .shape [:2 ], dtype = bool , device = device )
319
+ anchor_matched_targets = torch .cat ([align_cls , align_bbox ], dim = - 1 )
320
+ return anchor_matched_targets , valid_mask
321
+
311
322
target_cls , target_bbox = target .split ([1 , 4 ], dim = - 1 ) # B x N x (C B) -> B x N x C, B x N x B
312
323
target_cls = target_cls .long ().clamp (0 )
313
324
@@ -341,8 +352,8 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
341
352
normalize_term = (target_matrix / (max_target + 1e-9 )) * max_iou
342
353
normalize_term = normalize_term .permute (0 , 2 , 1 ).gather (2 , unique_indices )
343
354
align_cls = align_cls * normalize_term * valid_mask [:, :, None ]
344
-
345
- return torch . cat ([ align_cls , align_bbox ], dim = - 1 ), valid_mask . bool ()
355
+ anchor_matched_targets = torch . cat ([ align_cls , align_bbox ], dim = - 1 )
356
+ return anchor_matched_targets , valid_mask
346
357
347
358
348
359
class YOLOv9Criterion (nn .Module ):
0 commit comments