diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index ca94286..63896e3 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -2,7 +2,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch -import torch.nn.functional as F from einops import rearrange from torch import Tensor, tensor from torchmetrics.detection import MeanAveragePrecision @@ -235,7 +234,8 @@ def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: tensor) -> Tensor: topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions. """ values, indices = target_matrix.max(dim=-1) - best_anchor_mask = F.one_hot(indices, target_matrix.size(-1)) + best_anchor_mask = torch.zeros_like(target_matrix, dtype=torch.bool) + best_anchor_mask.scatter_(-1, index=indices[..., None], src=~best_anchor_mask) matched_anchor_num = torch.sum(topk_mask, dim=-1) target_without_anchor = (matched_anchor_num == 0) & (values > 0) topk_mask = torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask) @@ -256,10 +256,12 @@ def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor): """ duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1]) masked_iou_mat = topk_mask * iou_mat - max_idx = F.one_hot(masked_iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1) - topk_mask = torch.where(duplicates, max_idx, topk_mask) - unique_indices = topk_mask.argmax(dim=1) - return unique_indices[..., None], topk_mask.sum(1), topk_mask + best_indices = masked_iou_mat.argmax(1)[:, None, :] + best_target_mask = torch.zeros_like(duplicates, dtype=torch.bool) + best_target_mask.scatter_(1, index=best_indices, src=~best_target_mask) + topk_mask = torch.where(duplicates, best_target_mask, topk_mask) + unique_indices = topk_mask.to(torch.uint8).argmax(dim=1) + return unique_indices[..., None], topk_mask.any(dim=1), topk_mask def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: """Matches each target to the most suitable anchor. @@ -317,8 +319,9 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask) align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4)) - align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1) - align_cls = F.one_hot(align_cls, self.class_num) + align_cls_indices = torch.gather(target_cls, 1, unique_indices) + align_cls = torch.zeros_like(align_cls_indices, dtype=torch.bool).repeat(1, 1, self.class_num) + align_cls.scatter_(-1, index=align_cls_indices, src=~align_cls) # normalize class ditribution iou_mat *= topk_mask @@ -329,7 +332,7 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices) align_cls = align_cls * normalize_term * valid_mask[:, :, None] anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1) - return anchor_matched_targets, valid_mask.bool() + return anchor_matched_targets, valid_mask class Vec2Box: