diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 12d95c5..1d714d3 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -222,12 +222,37 @@ def filter_duplicates(self, target_matrix: Tensor): return unique_indices[..., None] def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: - """ - 1. For each anchor prediction, find the highest suitability targets - 2. Select the targets - 2. Noramlize the class probilities of targets + """Matches each target to the most suitable anchor. + 1. For each anchor prediction, find the highest suitability targets. + 2. Match target to the best anchor. + 3. Noramlize the class probilities of targets. + + Args: + target: The ground truth class and bounding box information + as tensor of size [batch x targets x 5]. + predict: Tuple of predicted class and bounding box tensors. + Class tensor is of size [batch x anchors x class] + Bounding box tensor is of size [batch x anchors x 4]. + + Returns: + anchor_matched_targets: Tensor of size [batch x anchors x (class + 4)]. + A tensor assigning each target/gt to the best fitting anchor. + The class probabilities are normalized. + valid_mask: Bool tensor of shape [batch x anchors]. + True if a anchor has a target/gt assigned to it. """ predict_cls, predict_bbox = predict + + # return if target has no gt information. + n_targets = target.shape[1] + if n_targets == 0: + device = predict_bbox.device + align_cls = torch.zeros_like(predict_cls, device=device) + align_bbox = torch.zeros_like(predict_bbox, device=device) + valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device) + anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1) + return anchor_matched_targets, valid_mask + target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B target_cls = target_cls.long().clamp(0) @@ -261,8 +286,8 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices) align_cls = align_cls * normalize_term * valid_mask[:, :, None] - - return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool() + anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1) + return anchor_matched_targets, valid_mask class Vec2Box: