From fbc11d0eebde3bda3b6c5ea3ebc83c7c8e0ac1ce Mon Sep 17 00:00:00 2001 From: Olof Harrysson Date: Mon, 13 Dec 2021 11:33:47 +0100 Subject: [PATCH 1/2] fix: Properly checks if ground truths are empty --- tests/detection/test_map.py | 15 +++++++++++++++ torchmetrics/detection/map.py | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 83d850c8d4b..f791771b64b 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -214,6 +214,21 @@ def test_empty_preds(): ) metric.compute() +@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") +def test_empty_ground_truths(): + """Test empty ground truths.""" + metric = MAP() + + metric.update( + [ + dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), scores=torch.Tensor([0.5]), labels=torch.IntTensor([4])), + ], + [ + dict(boxes=torch.Tensor([]), labels=torch.IntTensor([])), + ], + ) + metric.compute() + _gpu_test_condition = not torch.cuda.is_available() diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index d40370857d7..39c16ec5a2e 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -356,7 +356,7 @@ def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor: det = self.detection_boxes[id] gt_label_mask = self.groundtruth_labels[id] == class_id det_label_mask = self.detection_labels[id] == class_id - if len(det_label_mask) == 0 or len(det_label_mask) == 0: + if len(gt_label_mask) == 0 or len(det_label_mask) == 0: return Tensor([]) gt = gt[gt_label_mask] det = det[det_label_mask] @@ -396,7 +396,7 @@ def _evaluate_image( det = self.detection_boxes[id] gt_label_mask = self.groundtruth_labels[id] == class_id det_label_mask = self.detection_labels[id] == class_id - if len(det_label_mask) == 0 or len(det_label_mask) == 0: + if len(gt_label_mask) == 0 or len(det_label_mask) == 0: return None gt = gt[gt_label_mask] det = det[det_label_mask] From 591fb550490b606b943e03eec2b1a72f8e2d7dce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Dec 2021 11:06:45 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/detection/test_map.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index f791771b64b..8b02ca160f6 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -214,6 +214,7 @@ def test_empty_preds(): ) metric.compute() + @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") def test_empty_ground_truths(): """Test empty ground truths.""" @@ -221,7 +222,11 @@ def test_empty_ground_truths(): metric.update( [ - dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), scores=torch.Tensor([0.5]), labels=torch.IntTensor([4])), + dict( + boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), + scores=torch.Tensor([0.5]), + labels=torch.IntTensor([4]), + ), ], [ dict(boxes=torch.Tensor([]), labels=torch.IntTensor([])),