From dd53dfb41638196f5992db554481cbd5dfc301eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kun=C3=A1k?= <38215643+Adamusen@users.noreply.github.com> Date: Thu, 7 Nov 2024 12:06:53 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20[Fix]=20wrong=20filter=5Fduplica?= =?UTF-8?q?tes=20argument?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The filter_duplicates() function should receive the topk_targets instead of the iou_mat as its first argument. --- yolo/utils/bounding_box_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 20539be..54eaae1 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -278,7 +278,7 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk) # delete one anchor pred assign to mutliple gts - unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask) + unique_indices, valid_mask, topk_mask = self.filter_duplicates(topk_targets, topk_mask) align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4)) align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)