Skip to content

Commit

Permalink
🔨 fix: BoxMatcher.__call__ now returns all zero anchor matched targets and all False valid mask, if input target has zero annotations in it. (#88)
Browse files Browse the repository at this point in the history

docs: Updated docstrings.
  • Loading branch information
Abdul-Mukit authored Oct 11, 2024
1 parent dea5a8a commit e53ff09
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,37 @@ def filter_duplicates(self, target_matrix: Tensor):
return unique_indices[..., None]

def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
"""
1. For each anchor prediction, find the highest suitability targets
2. Select the targets
2. Noramlize the class probilities of targets
"""Matches each target to the most suitable anchor.
1. For each anchor prediction, find the highest suitability targets.
2. Match target to the best anchor.
3. Normalize the class probabilities of targets.
Args:
target: The ground truth class and bounding box information
as tensor of size [batch x targets x 5].
predict: Tuple of predicted class and bounding box tensors.
Class tensor is of size [batch x anchors x class]
Bounding box tensor is of size [batch x anchors x 4].
Returns:
anchor_matched_targets: Tensor of size [batch x anchors x (class + 4)].
A tensor assigning each target/gt to the best fitting anchor.
The class probabilities are normalized.
valid_mask: Bool tensor of shape [batch x anchors].
True if an anchor has a target/gt assigned to it.
"""
predict_cls, predict_bbox = predict

# return if target has no gt information.
n_targets = target.shape[1]
if n_targets == 0:
device = predict_bbox.device
align_cls = torch.zeros_like(predict_cls, device=device)
align_bbox = torch.zeros_like(predict_bbox, device=device)
valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask

target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
target_cls = target_cls.long().clamp(0)

Expand Down Expand Up @@ -261,8 +286,8 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
align_cls = align_cls * normalize_term * valid_mask[:, :, None]

return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask


class Vec2Box:
Expand Down

0 comments on commit e53ff09

Please sign in to comment.