Skip to content

Commit

Permalink
🔨 [Update] F.one_hot calls in BoxMatcher
Browse files Browse the repository at this point in the history
to a more efficient solution, without using torch.nn.functional.

torch.nn.functional.one_hot always returns a long (int64) tensor, which consumes a lot of memory for tensors that are only used as masks.
  • Loading branch information
Adamusen authored Nov 28, 2024
1 parent 94aa2e7 commit 48bcb9b
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, tensor
from torchmetrics.detection import MeanAveragePrecision
Expand Down Expand Up @@ -235,7 +234,8 @@ def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: tensor) -> Tensor:
topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
"""
values, indices = target_matrix.max(dim=-1)
best_anchor_mask = F.one_hot(indices, target_matrix.size(-1))
best_anchor_mask = torch.zeros_like(target_matrix, dtype=torch.bool)
best_anchor_mask.scatter_(-1, index=indices[..., None], src=~best_anchor_mask)
matched_anchor_num = torch.sum(topk_mask, dim=-1)
target_without_anchor = (matched_anchor_num == 0) & (values > 0)
topk_mask = torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask)
Expand All @@ -256,10 +256,12 @@ def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor):
"""
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
masked_iou_mat = topk_mask * iou_mat
max_idx = F.one_hot(masked_iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
topk_mask = torch.where(duplicates, max_idx, topk_mask)
unique_indices = topk_mask.argmax(dim=1)
return unique_indices[..., None], topk_mask.sum(1), topk_mask
best_indices = masked_iou_mat.argmax(1)[:, None, :]
best_target_mask = torch.zeros_like(duplicates, dtype=torch.bool)
best_target_mask.scatter_(1, index=best_indices, src=~best_target_mask)
topk_mask = torch.where(duplicates, best_target_mask, topk_mask)
unique_indices = topk_mask.to(torch.uint8).argmax(dim=1)
return unique_indices[..., None], topk_mask.any(dim=1), topk_mask

def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
"""Matches each target to the most suitable anchor.
Expand Down Expand Up @@ -317,8 +319,9 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)

align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
align_cls = F.one_hot(align_cls, self.class_num)
align_cls_indices = torch.gather(target_cls, 1, unique_indices)
align_cls = torch.zeros_like(align_cls_indices, dtype=torch.bool).repeat(1, 1, self.class_num)
align_cls.scatter_(-1, index=align_cls_indices, src=~align_cls)

# normalize class distribution
iou_mat *= topk_mask
Expand All @@ -329,7 +332,7 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask.bool()
return anchor_matched_targets, valid_mask


class Vec2Box:
Expand Down

0 comments on commit 48bcb9b

Please sign in to comment.