Skip to content

Commit fa272d5

Browse files
authored
Fix empty bbox error for YOLOv9 (#4024)
* Fix empty bbox error * Add unit test * precommit
1 parent 08fe9d6 commit fa272d5

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

src/otx/algo/detection/losses/yolov9_loss.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,20 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
305305
predict (tuple[Tensor, Tensor]): The predicted class and bounding box.
306306
307307
Returns:
308-
tuple[Tensor, Tensor]: The aligned target tensor with (batch, targets, (class + 4)).
308+
tuple[Tensor, Tensor]: The aligned target tensors with (batch, targets, (class + 4)) and (batch, targets).
309309
"""
310310
predict_cls, predict_bbox = predict
311+
312+
# return if target has no gt information.
313+
n_targets = target.shape[1]
314+
if n_targets == 0:
315+
device = predict_bbox.device
316+
align_cls = torch.zeros_like(predict_cls, device=device)
317+
align_bbox = torch.zeros_like(predict_bbox, device=device)
318+
valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
319+
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
320+
return anchor_matched_targets, valid_mask
321+
311322
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
312323
target_cls = target_cls.long().clamp(0)
313324

@@ -341,8 +352,8 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
341352
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
342353
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
343354
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
344-
345-
return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
355+
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
356+
return anchor_matched_targets, valid_mask
346357

347358

348359
class YOLOv9Criterion(nn.Module):

tests/unit/algo/detection/losses/test_yolov9_loss.py

+15
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,21 @@ def test_call(self, box_matcher: BoxMatcher) -> None:
163163
assert align_targets.shape == torch.Size([1, 3, 14])
164164
assert valid_masks.shape == torch.Size([1, 3])
165165

166+
def test_call_with_empty_bbox(self, box_matcher: BoxMatcher) -> None:
167+
target = torch.zeros((1, 0, 5))
168+
169+
predict_cls = torch.rand((1, 8400, 10))
170+
predict_bbox = torch.rand((1, 8400, 4))
171+
predict = (predict_cls, predict_bbox)
172+
173+
align_targets, valid_masks = box_matcher(target, predict)
174+
175+
assert align_targets.shape == (1, 8400, 14)
176+
assert torch.all(align_targets == 0)
177+
178+
assert valid_masks.shape == (1, 8400)
179+
assert torch.all(~valid_masks)
180+
166181

167182
class TestYOLOv9Criterion:
168183
@pytest.fixture()

0 commit comments

Comments
 (0)